Browse Source

refactor: update python ffi

master
Dnomd343 1 month ago
parent
commit
9f4bfae608
  1. 5
      src/core_ffi/python_ffi/CMakeLists.txt
  2. 32
      src/core_ffi/python_ffi/binder/common_code.cc
  3. 14
      src/core_ffi/python_ffi/binder/fast_cal.cc
  4. 14
      src/core_ffi/python_ffi/binder/group.cc
  5. 16
      src/core_ffi/python_ffi/binder/group_union.cc
  6. 15
      src/core_ffi/python_ffi/binder/klsk_cases.cc
  7. 72
      src/core_ffi/python_ffi/binder/klsk_code.cc
  8. 28
      src/core_ffi/python_ffi/binder/short_code.cc
  9. 21
      src/core_ffi/python_ffi/include/binder.h
  10. 5
      src/core_ffi/python_ffi/include/py_ffi/cases.h
  11. 34
      src/core_ffi/python_ffi/include/py_ffi/common_code.h
  12. 4
      src/core_ffi/python_ffi/include/py_ffi/fast_cal.h
  13. 22
      src/core_ffi/python_ffi/include/py_ffi/short_code.h
  14. 25
      src/core_ffi/python_ffi/klotski.cc
  15. 19
      src/core_ffi/python_ffi/wrapper/common_codec.cc
  16. 3
      src/core_ffi/python_ffi/wrapper/group_union.cc
  17. 3
      src/core_ffi/python_ffi/wrapper/short_code.cc

5
src/core_ffi/python_ffi/CMakeLists.txt

@ -6,9 +6,8 @@ set(CMAKE_CXX_STANDARD 23)
set(KLSK_PYTHON_FFI_SRC
klotski.cc
binder/cases.cc
binder/common_code.cc
binder/short_code.cc
binder/klsk_cases.cc
binder/klsk_code.cc
binder/group_union.cc
binder/group.cc
binder/fast_cal.cc

32
src/core_ffi/python_ffi/binder/common_code.cc

@ -1,32 +0,0 @@
#include "binder.h"
void bind_common_code(const py::module_ &m) {
py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>())
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
.def(py::hash(py::self))
.def("__str__", &PyCommonCode::str)
.def("__int__", &PyCommonCode::value)
.def("__repr__", &PyCommonCode::repr)
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint64_t())
.def(py::self < uint64_t()).def(py::self <= uint64_t())
.def(py::self > uint64_t()).def(py::self >= uint64_t())
// .def_property_readonly("str", &PyCommonCode::string)
.def_property_readonly("value", &PyCommonCode::value)
.def_property_readonly("short_code", &PyCommonCode::short_code)
.def("next_cases", &PyCommonCode::next_cases)
.def("to_string", &PyCommonCode::string, py::arg("shorten") = false)
.def_static("check", static_cast<bool (*)(uint64_t)>(&PyCommonCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyCommonCode::check));
}

14
src/core_ffi/python_ffi/binder/fast_cal.cc

@ -2,6 +2,20 @@
#include "binder.h"
#include "py_ffi/cases.h"
#include "py_ffi/group.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
#include "py_ffi/fast_cal.h"
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
using klotski::ffi::PyGroupUnion;
using klotski::ffi::PyGroup;
using klotski::ffi::PyFastCal;
void bind_fast_cal(const py::module_ &m) {
py::class_<PyFastCal>(m, "FastCal")
.def(py::init<PyCommonCode>())

14
src/core_ffi/python_ffi/binder/group.cc

@ -1,5 +1,19 @@
#include "binder.h"
#include "py_ffi/cases.h"
#include "py_ffi/group.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
#include "py_ffi/fast_cal.h"
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
using klotski::ffi::PyGroupUnion;
using klotski::ffi::PyGroup;
using klotski::ffi::PyFastCal;
void bind_group(const py::module_ &m) {
py::class_<PyGroup>(m, "Group")
.def_property_readonly("type_id", &PyGroup::type_id)

16
src/core_ffi/python_ffi/binder/group_union.cc

@ -1,5 +1,21 @@
#include "binder.h"
#include "py_ffi/cases.h"
#include "py_ffi/group.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
#include "py_ffi/fast_cal.h"
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
using klotski::ffi::PyGroupUnion;
using klotski::ffi::PyGroup;
using klotski::ffi::PyFastCal;
// TODO: move `bind_group` here
void bind_group_union(const py::module_ &m) {
py::class_<PyGroupUnion>(m, "GroupUnion")
.def(py::init<uint8_t>())

15
src/core_ffi/python_ffi/binder/cases.cc → src/core_ffi/python_ffi/binder/klsk_cases.cc

@ -1,6 +1,12 @@
#include <pybind11/operators.h>
#include "binder.h"
#include "py_ffi/cases.h"
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
void bind_cases(const py::module_ &mod) {
static void bind_cases(const py::module_ &mod) {
py::class_<PyCases>(mod, "Cases")
.def(py::self == py::self)
.def("__len__", &PyCases::size)
@ -10,7 +16,14 @@ void bind_cases(const py::module_ &mod) {
.def_property_readonly_static("all_cases", [](const py::object&) {
return PyCases::all_cases();
});
}
static void bind_cases_iter(const py::module_ &mod) {
py::class_<PyCasesIter>(mod, "CasesIter")
.def("__next__", &PyCasesIter::next);
}
void bind_klsk_cases(const py::module_ &mod) {
bind_cases(mod);
bind_cases_iter(mod);
}

72
src/core_ffi/python_ffi/binder/klsk_code.cc

@ -0,0 +1,72 @@
#include <pybind11/operators.h>
#include "binder.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
static void bind_common_code(const py::module_ &mod) {
py::class_<PyCommonCode>(mod, "Code")
.def(py::init<uint64_t>())
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint64_t())
.def(py::self < uint64_t()).def(py::self <= uint64_t())
.def(py::self > uint64_t()).def(py::self >= uint64_t())
.def(py::hash(py::self))
.def("__str__", &PyCommonCode::str)
.def("__int__", &PyCommonCode::value)
.def("__repr__", &PyCommonCode::repr)
.def("next_cases", &PyCommonCode::next_cases)
// TODO: add fast_cal / fast_cal_multi / ...
.def("to_short_code", &PyCommonCode::short_code)
.def("to_string", &PyCommonCode::string, py::arg("shorten") = false)
.def_property_readonly("value", &PyCommonCode::value)
// TODO: add n_1x1 / n_1x2 / n_2x1 / ...
.def_static("check", py::overload_cast<uint64_t>(&PyCommonCode::check))
.def_static("check", py::overload_cast<std::string_view>(&PyCommonCode::check));
}
static void bind_short_code(const py::module_ &mod) {
py::class_<PyShortCode>(mod, "ShortCode")
.def(py::init<uint32_t>())
.def(py::init<PyCommonCode>())
.def(py::init<std::string_view>())
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint32_t())
.def(py::self < uint32_t()).def(py::self <= uint32_t())
.def(py::self > uint32_t()).def(py::self >= uint32_t())
.def(py::hash(py::self))
.def("__str__", &PyShortCode::str)
.def("__int__", &PyShortCode::value)
.def("__repr__", &PyShortCode::repr)
.def_property_readonly("value", &PyShortCode::value)
.def("to_common_code", &PyShortCode::common_code)
.def_static("check", py::overload_cast<uint32_t>(&PyShortCode::check))
.def_static("check", py::overload_cast<std::string_view>(&PyShortCode::check))
.def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false);
}
void bind_klsk_code(const py::module_ &mod) {
bind_short_code(mod);
bind_common_code(mod);
}

28
src/core_ffi/python_ffi/binder/short_code.cc

@ -1,28 +0,0 @@
#include "binder.h"
void bind_short_code(const py::module_ &m) {
py::class_<PyShortCode>(m, "ShortCode")
.def(py::init<uint32_t>())
.def(py::init<PyCommonCode>())
.def(py::init<std::string_view>())
.def(py::hash(py::self))
.def("__str__", &PyShortCode::str)
.def("__int__", &PyShortCode::value)
.def("__repr__", &PyShortCode::repr)
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint32_t())
.def(py::self < uint32_t()).def(py::self <= uint32_t())
.def(py::self > uint32_t()).def(py::self >= uint32_t())
.def_property_readonly("value", &PyShortCode::value)
.def_property_readonly("common_code", &PyShortCode::common_code)
.def_static("check", static_cast<bool (*)(uint32_t)>(&PyShortCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyShortCode::check))
.def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false);
}

21
src/core_ffi/python_ffi/include/binder.h

@ -1,30 +1,15 @@
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
namespace py = pybind11;
#include "py_ffi/cases.h"
#include "py_ffi/group.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
#include "py_ffi/fast_cal.h"
void bind_klsk_code(const py::module_ &mod);
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
using klotski::ffi::PyGroupUnion;
using klotski::ffi::PyGroup;
using klotski::ffi::PyFastCal;
void bind_klsk_cases(const py::module_ &mod);
void bind_cases(const py::module_ &mod);
void bind_short_code(const py::module_ &m);
void bind_common_code(const py::module_ &m);
// TODO: add `bind_klsk_group` and `bind_klsk_fast_cal`
void bind_group_union(const py::module_ &m);
void bind_group(const py::module_ &m);
void bind_fast_cal(const py::module_ &m);

5
src/core_ffi/python_ffi/include/py_ffi/cases.h

@ -3,8 +3,7 @@
#pragma once
#include <variant>
#include "ranges/ranges.h"
#include <ranges/ranges.h>
#include "py_ffi/common_code.h"
@ -30,6 +29,8 @@ class PyCases {
public:
PyCases() = delete;
// TODO: add pickle support
// ------------------------------------------------------------------------------------- //
/// Constructing from r-value.

34
src/core_ffi/python_ffi/include/py_ffi/common_code.h

@ -7,11 +7,9 @@
#include "py_ffi/short_code.h"
namespace klotski::ffi {
using codec::CommonCode;
// TODO: maybe using `PyLayout` instead of `PyCommonCode`
class PyShortCode;
namespace klotski::ffi {
class PyCommonCode {
public:
@ -57,21 +55,17 @@ public:
// ------------------------------------------------------------------------------------- //
[[nodiscard]] std::vector<PyCommonCode> next_cases() const noexcept {
std::vector<PyCommonCode> cases;
auto mover = mover::MaskMover([&cases](const codec::RawCode code, uint64_t) {
cases.emplace_back(std::bit_cast<PyCommonCode>(code.to_common_code()));
});
mover.next_cases(code_.to_raw_code(), 0);
return cases;
}
[[nodiscard]] std::vector<PyCommonCode> next_cases() const noexcept;
// ------------------------------------------------------------------------------------- //
private:
CommonCode code_;
codec::CommonCode code_;
};
static_assert(std::is_trivially_copyable_v<PyCommonCode>);
static_assert(sizeof(PyCommonCode) == sizeof(codec::CommonCode));
// ----------------------------------------------------------------------------------------- //
constexpr auto operator==(const PyCommonCode &lhs, const uint64_t rhs) {
@ -94,17 +88,9 @@ constexpr auto operator<=>(const PyCommonCode &lhs, const PyCommonCode &rhs) {
} // namespace klotski::ffi
// ----------------------------------------------------------------------------------------- //
namespace std {
template <>
struct hash<klotski::ffi::PyCommonCode> {
size_t operator()(const klotski::ffi::PyCommonCode &common_code) const noexcept {
return std::hash<uint64_t>{}(common_code.value());
struct std::hash<klotski::ffi::PyCommonCode> {
size_t operator()(const klotski::ffi::PyCommonCode &code) const noexcept {
return std::hash<uint64_t>{}(code.value());
}
};
} // namespace std
// ----------------------------------------------------------------------------------------- //

4
src/core_ffi/python_ffi/include/py_ffi/fast_cal.h

@ -13,7 +13,7 @@ using fast_cal::FastCal;
class PyFastCal {
public:
explicit PyFastCal(PyCommonCode code)
: fast_cal_(FastCal(std::bit_cast<CommonCode>(code).to_raw_code())) {}
: fast_cal_(FastCal(std::bit_cast<codec::CommonCode>(code).to_raw_code())) {}
// TODO: export solution path directly
std::optional<PyCommonCode> solve() {
@ -32,7 +32,7 @@ public:
[[nodiscard]] std::vector<PyCommonCode> backtrack(PyCommonCode code) const {
std::vector<PyCommonCode> path;
for (auto x : fast_cal_.backtrack(std::bit_cast<CommonCode>(code).to_raw_code())) {
for (auto x : fast_cal_.backtrack(std::bit_cast<codec::CommonCode>(code).to_raw_code())) {
path.emplace_back(std::bit_cast<PyCommonCode>(x.to_common_code()));
}
return path;

22
src/core_ffi/python_ffi/include/py_ffi/short_code.h

@ -4,12 +4,8 @@
#include <short_code/short_code.h>
#include "py_ffi/common_code.h"
namespace klotski::ffi {
using codec::ShortCode;
class PyCommonCode;
class PyShortCode {
@ -55,11 +51,11 @@ public:
/// Build conversion index for ShortCode.
static void speed_up(const bool fast_mode) { // TODO: move to `SpeedUp`
ShortCode::speed_up(fast_mode);
codec::ShortCode::speed_up(fast_mode);
}
private:
ShortCode code_;
codec::ShortCode code_;
};
// ----------------------------------------------------------------------------------------- //
@ -84,17 +80,9 @@ constexpr auto operator<=>(const PyShortCode &lhs, const PyShortCode &rhs) {
} // namespace klotski::ffi
// ----------------------------------------------------------------------------------------- //
namespace std {
template <>
struct hash<klotski::ffi::PyShortCode> {
size_t operator()(const klotski::ffi::PyShortCode &short_code) const noexcept {
return std::hash<uint32_t>{}(short_code.value());
struct std::hash<klotski::ffi::PyShortCode> {
size_t operator()(const klotski::ffi::PyShortCode &code) const noexcept {
return std::hash<uint32_t>{}(code.value());
}
};
} // namespace std
// ----------------------------------------------------------------------------------------- //

25
src/core_ffi/python_ffi/klotski.cc

@ -1,43 +1,26 @@
#include <pybind11/pybind11.h>
// #include <pybind11/stl.h>
// #include "py_ffi/common_code.h"
// #include "py_ffi/short_code.h"
// #include "py_ffi/cases.h"
// #include "py_ffi/group.h"
#include "exception.h"
#include "binder.h"
namespace py = pybind11;
#include "exception.h"
using klotski::ffi::PyExc_CodecError;
using klotski::ffi::PyExc_GroupError;
#include "group/group.h"
#include "all_cases/all_cases.h"
// static PyCases group_demo() {
// auto group_union = klotski::group::GroupUnion::unsafe_create(169);
// auto cases = PyCases::from(group_union.cases());
// return cases;
// }
//
// static PyCases all_cases() {
// return PyCases::from_ref(klotski::cases::AllCases::instance().fetch());
// }
PYBIND11_MODULE(_klotski, m) {
py::register_exception<PyExc_GroupError>(m, "GroupError", PyExc_ValueError);
py::register_exception<PyExc_CodecError>(m, "CodecError", PyExc_ValueError);
// m.def("all_cases", &all_cases);
// m.def("group_demo", &group_demo);
bind_cases(m);
bind_short_code(m);
bind_common_code(m);
bind_klsk_cases(m);
bind_klsk_code(m);
bind_group(m);
bind_group_union(m);
bind_fast_cal(m);

19
src/core_ffi/python_ffi/wrapper/common_codec.cc

@ -4,6 +4,8 @@
#include "py_ffi/common_code.h"
using namespace klotski::ffi;
using klotski::codec::ShortCode;
using klotski::codec::CommonCode;
// ----------------------------------------------------------------------------------------- //
@ -31,12 +33,23 @@ bool PyCommonCode::check(const std::string_view code) noexcept {
// ----------------------------------------------------------------------------------------- //
[[nodiscard]] std::vector<PyCommonCode> PyCommonCode::next_cases() const noexcept {
std::vector<PyCommonCode> cases;
auto mover = mover::MaskMover([&cases](const codec::RawCode code, uint64_t) {
cases.emplace_back(std::bit_cast<PyCommonCode>(code.to_common_code()));
});
mover.next_cases(code_.to_raw_code(), 0);
return cases;
}
// ----------------------------------------------------------------------------------------- //
std::string PyCommonCode::str(const PyCommonCode code) noexcept {
return code.code_.to_string();
}
std::string PyCommonCode::repr(const PyCommonCode code) noexcept {
return std::format("<klotski.CommonCode 0x{}>", str(code));
return std::format("<klotski.Code 0x{}>", str(code));
}
// ----------------------------------------------------------------------------------------- //
@ -49,14 +62,14 @@ static CommonCode convert(const uint64_t code) {
if (CommonCode::check(code)) {
return CommonCode::unsafe_create(code);
}
throw PyExc_CodecError(std::format("invalid common code -> {}", code));
throw PyExc_CodecError(std::format("invalid code: {}", code));
}
static CommonCode convert(const std::string_view code) {
if (const auto str = CommonCode::from_string(code)) {
return str.value();
}
throw PyExc_CodecError(std::format("invalid common code -> `{}`", code));
throw PyExc_CodecError(std::format("invalid code: `{}`", code));
}
PyCommonCode::PyCommonCode(const uint64_t code) : code_(convert(code)) {}

3
src/core_ffi/python_ffi/wrapper/group_union.cc

@ -1,6 +1,9 @@
#include "exception.h"
#include "py_ffi/group.h"
using klotski::codec::ShortCode;
using klotski::codec::CommonCode;
using klotski::ffi::PyExc_GroupError;
using klotski::ffi::PyGroupUnion;

3
src/core_ffi/python_ffi/wrapper/short_code.cc

@ -2,8 +2,11 @@
#include "exception.h"
#include "py_ffi/short_code.h"
#include "py_ffi/common_code.h"
using namespace klotski::ffi;
using klotski::codec::ShortCode;
using klotski::codec::CommonCode;
// ----------------------------------------------------------------------------------------- //

Loading…
Cancel
Save