diff --git a/src/core_ffi/python_ffi/binder/cases.cc b/src/core_ffi/python_ffi/binder/cases.cc index 3a1eccb..8300caf 100644 --- a/src/core_ffi/python_ffi/binder/cases.cc +++ b/src/core_ffi/python_ffi/binder/cases.cc @@ -1,13 +1,16 @@ #include "binder.h" -void bind_cases(const py::module_ &m) { - py::class_(m, "Cases") +void bind_cases(const py::module_ &mod) { + py::class_(mod, "Cases") .def(py::self == py::self) .def("__len__", &PyCases::size) .def("__repr__", &PyCases::repr) .def("__getitem__", &PyCases::at) - .def("__iter__", &PyCases::codes, py::keep_alive<0, 1>()); + .def("__iter__", &PyCases::iter, py::keep_alive<0, 1>()) + .def_property_readonly_static("all_cases", [](const py::object&) { + return PyCases::all_cases(); + }); - py::class_(m, "CasesIter") + py::class_(mod, "CasesIter") .def("__next__", &PyCasesIter::next); } diff --git a/src/core_ffi/python_ffi/include/binder.h b/src/core_ffi/python_ffi/include/binder.h index 237491e..5d62e9d 100644 --- a/src/core_ffi/python_ffi/include/binder.h +++ b/src/core_ffi/python_ffi/include/binder.h @@ -20,7 +20,7 @@ using klotski::ffi::PyGroupUnion; using klotski::ffi::PyGroup; using klotski::ffi::PyFastCal; -void bind_cases(const py::module_ &m); +void bind_cases(const py::module_ &mod); void bind_short_code(const py::module_ &m); void bind_common_code(const py::module_ &m); diff --git a/src/core_ffi/python_ffi/include/py_ffi/cases.h b/src/core_ffi/python_ffi/include/py_ffi/cases.h index 96380af..75fe1b0 100644 --- a/src/core_ffi/python_ffi/include/py_ffi/cases.h +++ b/src/core_ffi/python_ffi/include/py_ffi/cases.h @@ -3,7 +3,8 @@ #pragma once #include -#include + +#include "ranges/ranges.h" #include "py_ffi/common_code.h" @@ -14,14 +15,14 @@ using cases::RangesUnion; class PyCasesIter { public: /// Construct from RangesUnion reference. - explicit PyCasesIter(const RangesUnion &data); + explicit PyCasesIter(const RangesUnion &data) noexcept; - /// Get the next CommonCode or throw a stop_iteration exception. + /// Get the next CommonCode or throw `stop_iteration` exception. PyCommonCode next(); private: - uint8_t head_ {0}; - uint32_t index_ {0}; + size_t index_ {0}; + uint64_t head_ {0}; const RangesUnion &data_; }; @@ -29,8 +30,6 @@ class PyCases { public: PyCases() = delete; - // TODO: add `all_cases` interface - // ------------------------------------------------------------------------------------- // /// Constructing from r-value. @@ -45,18 +44,21 @@ public: [[nodiscard]] size_t size() const noexcept; /// Get CommonCode iterator of cases. - [[nodiscard]] PyCasesIter codes() const noexcept; + [[nodiscard]] PyCasesIter iter() const noexcept; /// Get the CommonCode of the specified index. - [[nodiscard]] PyCommonCode at(size_t index) const; // TODO: allow `-1` index + [[nodiscard]] PyCommonCode at(int32_t index) const; // ------------------------------------------------------------------------------------- // + /// Export all klotski cases. + static PyCases all_cases() noexcept; + /// Wrapper of `__repr__` method in Python. static std::string repr(const PyCases &cases) noexcept; /// Compare the cases contents of two PyCases. - friend constexpr auto operator==(const PyCases &lhs, const PyCases &rhs); + friend constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) noexcept; // ------------------------------------------------------------------------------------- // @@ -75,7 +77,7 @@ private: // ------------------------------------------------------------------------------------- // }; -constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) { +constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) noexcept { return lhs.data_ref() == rhs.data_ref(); } diff --git a/src/core_ffi/python_ffi/wrapper/cases.cc b/src/core_ffi/python_ffi/wrapper/cases.cc index d00feff..8fae0c3 100644 --- a/src/core_ffi/python_ffi/wrapper/cases.cc +++ b/src/core_ffi/python_ffi/wrapper/cases.cc @@ -6,17 +6,16 @@ namespace py = pybind11; using namespace klotski::ffi; +using klotski::cases::AllCases; // ----------------------------------------------------------------------------------------- // -PyCasesIter::PyCasesIter(const RangesUnion &data) : data_(data) {} +PyCasesIter::PyCasesIter(const RangesUnion &data) noexcept : data_(data) {} PyCommonCode PyCasesIter::next() { while (head_ < 16) { - const auto &ranges = data_.ranges(head_); - if (index_ < ranges.size()) { - auto code = (static_cast(head_) << 32) | ranges[index_++]; - return std::bit_cast(code); + if (const auto &ranges = data_.ranges(head_); index_ < ranges.size()) { + return std::bit_cast((head_ << 32) | ranges[index_++]); } index_ = 0, ++head_; } @@ -26,36 +25,26 @@ PyCommonCode PyCasesIter::next() { // ----------------------------------------------------------------------------------------- // size_t PyCases::size() const noexcept { - size_t num = 0; - for (const auto &x : data_ref()) { // TODO: fetch from RangesUnion.size() - num += x.size(); - } - return num; + return data_ref().size(); } -PyCasesIter PyCases::codes() const noexcept { +PyCasesIter PyCases::iter() const noexcept { return PyCasesIter(data_ref()); } -PyCommonCode PyCases::at(size_t index) const { - if (index >= size()) { +PyCommonCode PyCases::at(const int32_t index) const { + const auto size_ = static_cast(size()); + if (index >= size_ || index < -size_) { throw py::index_error("cases index out of range"); } + const auto code = data_ref()[index < 0 ? index + size_ : index]; + return std::bit_cast(code); +} - uint64_t head = 0; - for (;;) { - if (index >= data_ref().ranges(head).size()) { - index -= data_ref().ranges(head).size(); - ++head; - } else { - break; - } - } - uint32_t range = data_ref().ranges(head)[index]; +// ----------------------------------------------------------------------------------------- // - // TODO: fetch from RangesUnion[] - const auto code = CommonCode::unsafe_create(head << 32 | range); - return std::bit_cast(code); +PyCases PyCases::all_cases() noexcept { + return from_ref(AllCases::instance().fetch()); } std::string PyCases::repr(const PyCases &cases) noexcept {