diff --git a/src/core_ffi/CMakeLists.txt b/src/core_ffi/CMakeLists.txt index c7db17d..c6b8896 100644 --- a/src/core_ffi/CMakeLists.txt +++ b/src/core_ffi/CMakeLists.txt @@ -18,6 +18,7 @@ if (KLSK_PYTHON_FFI) py_ffi/wrapper/cases.cc py_ffi/wrapper/py_group_union.cc py_ffi/wrapper/py_group.cc + py_ffi/binder/cases.cc ) target_include_directories(klotski_py PRIVATE py_ffi/include) target_include_directories(klotski_py PRIVATE py_ffi) diff --git a/src/core_ffi/py_ffi/binder.cc b/src/core_ffi/py_ffi/binder.cc index b1a37e1..d6d5864 100644 --- a/src/core_ffi/py_ffi/binder.cc +++ b/src/core_ffi/py_ffi/binder.cc @@ -24,6 +24,8 @@ using klotski::ffi::PyExc_CodecError; using klotski::ffi::PyGroup; using klotski::ffi::PyGroupUnion; +extern void bind_cases(const py::module_ &m); + void bind_common_code(const py::module_ &m) { py::class_(m, "CommonCode") .def(py::init()) @@ -100,22 +102,11 @@ PYBIND11_MODULE(klotski, m) { m.def("all_cases", &all_cases); m.def("group_demo", &group_demo); - auto py_cases = py::class_(m, "Cases") - .def("size", &PyCases::size) - .def("__iter__", &PyCases::common_codes, py::keep_alive<0, 1>()) - .def("short_codes", &PyCases::short_codes, py::keep_alive<0, 1>()); - - py::class_(py_cases, "ShortCodeIter") - .def("__iter__", [](PyCases::ShortCodeIter &it) -> PyCases::ShortCodeIter& { return it; }) - .def("__next__", &PyCases::ShortCodeIter::next); - - py::class_(py_cases, "CommonCodeIter") - .def("__iter__", [](PyCases::CommonCodeIter &it) -> PyCases::CommonCodeIter& { return it; }) - .def("__next__", &PyCases::CommonCodeIter::next); - bind_short_code(m); bind_common_code(m); + bind_cases(m); + py::class_(m, "GroupUnion") .def(py::init()) .def(py::init()) diff --git a/src/core_ffi/py_ffi/binder/cases.cc b/src/core_ffi/py_ffi/binder/cases.cc new file mode 100644 index 0000000..bec2c00 --- /dev/null +++ b/src/core_ffi/py_ffi/binder/cases.cc @@ -0,0 +1,21 @@ +#include +#include + +#include "include/py_cases.h" + +namespace py = pybind11; + +using klotski::ffi::PyCases; +using klotski::ffi::PyCasesIter; + +void bind_cases(const py::module_ &m) { + py::class_(m, "Cases") + .def(py::self == py::self) + .def("__len__", &PyCases::size) + .def("__repr__", &PyCases::repr) + .def("__getitem__", &PyCases::at) + .def("__iter__", &PyCases::codes, py::keep_alive<0, 1>()); + + py::class_(m, "CasesIter") + .def("__next__", &PyCasesIter::next); +} diff --git a/src/core_ffi/py_ffi/include/py_cases.h b/src/core_ffi/py_ffi/include/py_cases.h index 576ff45..1f4ee6d 100644 --- a/src/core_ffi/py_ffi/include/py_cases.h +++ b/src/core_ffi/py_ffi/include/py_cases.h @@ -11,35 +11,27 @@ namespace klotski::ffi { using cases::RangesUnion; -class PyCases { +class PyCasesIter { public: - PyCases() = delete; - - // ------------------------------------------------------------------------------------- // + /// Construct from RangesUnion reference. + explicit PyCasesIter(const RangesUnion &data); - class CommonCodeIter { - public: - PyCommonCode next(); - explicit CommonCodeIter(const RangesUnion &data); + /// Get the next CommonCode or throw a stop_iteration exception. + PyCommonCode next(); - private: - uint8_t head_ {0}; - uint32_t index_ {0}; - const RangesUnion &data_; - }; - - class ShortCodeIter { - public: - PyShortCode next(); - explicit ShortCodeIter(CommonCodeIter iter); +private: + uint8_t head_ {0}; + uint32_t index_ {0}; + const RangesUnion &data_; +}; - private: - CommonCodeIter iter_; - }; +class PyCases { +public: + PyCases() = delete; // ------------------------------------------------------------------------------------- // - /// Constructing from rvalue. + /// Constructing from r-value. static PyCases from(RangesUnion &&data) noexcept; /// Constructing from longer-lived reference. @@ -48,20 +40,23 @@ public: // ------------------------------------------------------------------------------------- // /// Get the number of cases. - [[nodiscard]] size_t size() const; - - /// Get ShortCode iterator of cases. - [[nodiscard]] ShortCodeIter short_codes() const; + [[nodiscard]] size_t size() const noexcept; /// Get CommonCode iterator of cases. - [[nodiscard]] CommonCodeIter common_codes() const; + [[nodiscard]] PyCasesIter codes() const noexcept; /// Get the CommonCode of the specified index. - [[nodiscard]] PyCommonCode operator[](size_t index) const; + [[nodiscard]] PyCommonCode at(size_t index) const; // TODO: allow `-1` index // ------------------------------------------------------------------------------------- // - // TODO: add len / repr + /// Wrapper of `__repr__` method in Python. + static std::string repr(const PyCases &cases) noexcept; + + /// Compare the cases contents of two PyCases. + friend constexpr auto operator==(const PyCases &lhs, const PyCases &rhs); + + // ------------------------------------------------------------------------------------- // private: explicit PyCases(RangesUnion &&data); @@ -78,6 +73,8 @@ private: // ------------------------------------------------------------------------------------- // }; -// TODO: allow compare +constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) { + return lhs.data_ref() == rhs.data_ref(); +} } // namespace klotski::ffi diff --git a/src/core_ffi/py_ffi/wrapper/cases.cc b/src/core_ffi/py_ffi/wrapper/cases.cc index cc3cba1..f82f8c7 100644 --- a/src/core_ffi/py_ffi/wrapper/cases.cc +++ b/src/core_ffi/py_ffi/wrapper/cases.cc @@ -1,3 +1,4 @@ +#include #include #include "include/py_cases.h" @@ -6,35 +7,25 @@ namespace py = pybind11; using namespace klotski::ffi; -using ShortCodeIter = PyCases::ShortCodeIter; -using CommonCodeIter = PyCases::CommonCodeIter; - // ----------------------------------------------------------------------------------------- // -CommonCodeIter::CommonCodeIter(const RangesUnion &data) : data_(data) {} - -ShortCodeIter::ShortCodeIter(PyCases::CommonCodeIter iter) : iter_(iter) {} +PyCasesIter::PyCasesIter(const RangesUnion &data) : data_(data) {} -PyShortCode ShortCodeIter::next() { - return iter_.next().short_code(); -} - -PyCommonCode CommonCodeIter::next() { +PyCommonCode PyCasesIter::next() { while (head_ < 16) { const auto &ranges = data_[head_]; if (index_ < ranges.size()) { auto code = (static_cast(head_) << 32) | ranges[index_++]; return std::bit_cast(code); } - ++head_; - index_ = 0; + index_ = 0, ++head_; } throw py::stop_iteration(); } // ----------------------------------------------------------------------------------------- // -size_t PyCases::size() const { +size_t PyCases::size() const noexcept { size_t num = 0; for (const auto &x : data_ref()) { // TODO: fetch from RangesUnion.size() num += x.size(); @@ -42,16 +33,11 @@ size_t PyCases::size() const { return num; } -auto PyCases::short_codes() const -> ShortCodeIter { - return ShortCodeIter(common_codes()); +PyCasesIter PyCases::codes() const noexcept { + return PyCasesIter(data_ref()); } -auto PyCases::common_codes() const -> CommonCodeIter { - return CommonCodeIter(data_ref()); -} - -PyCommonCode PyCases::operator[](size_t index) const { - +PyCommonCode PyCases::at(size_t index) const { if (index >= size()) { throw py::index_error("cases index out of range"); } @@ -72,6 +58,10 @@ PyCommonCode PyCases::operator[](size_t index) const { return std::bit_cast(code); } +std::string PyCases::repr(const PyCases &cases) noexcept { + return std::format("", cases.size()); +} + // ----------------------------------------------------------------------------------------- // PyCases PyCases::from(RangesUnion &&data) noexcept {