Browse Source

update: enhance `PyCases` interfaces

legacy
Dnomd343 2 months ago
parent
commit
a8d03a56f8
  1. 1
      src/core_ffi/CMakeLists.txt
  2. 17
      src/core_ffi/py_ffi/binder.cc
  3. 21
      src/core_ffi/py_ffi/binder/cases.cc
  4. 57
      src/core_ffi/py_ffi/include/py_cases.h
  5. 34
      src/core_ffi/py_ffi/wrapper/cases.cc

1
src/core_ffi/CMakeLists.txt

@ -18,6 +18,7 @@ if (KLSK_PYTHON_FFI)
py_ffi/wrapper/cases.cc py_ffi/wrapper/cases.cc
py_ffi/wrapper/py_group_union.cc py_ffi/wrapper/py_group_union.cc
py_ffi/wrapper/py_group.cc py_ffi/wrapper/py_group.cc
py_ffi/binder/cases.cc
) )
target_include_directories(klotski_py PRIVATE py_ffi/include) target_include_directories(klotski_py PRIVATE py_ffi/include)
target_include_directories(klotski_py PRIVATE py_ffi) target_include_directories(klotski_py PRIVATE py_ffi)

17
src/core_ffi/py_ffi/binder.cc

@ -24,6 +24,8 @@ using klotski::ffi::PyExc_CodecError;
using klotski::ffi::PyGroup; using klotski::ffi::PyGroup;
using klotski::ffi::PyGroupUnion; using klotski::ffi::PyGroupUnion;
extern void bind_cases(const py::module_ &m);
void bind_common_code(const py::module_ &m) { void bind_common_code(const py::module_ &m) {
py::class_<PyCommonCode>(m, "CommonCode") py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>()) .def(py::init<uint64_t>())
@ -100,22 +102,11 @@ PYBIND11_MODULE(klotski, m) {
m.def("all_cases", &all_cases); m.def("all_cases", &all_cases);
m.def("group_demo", &group_demo); m.def("group_demo", &group_demo);
auto py_cases = py::class_<PyCases>(m, "Cases")
.def("size", &PyCases::size)
.def("__iter__", &PyCases::common_codes, py::keep_alive<0, 1>())
.def("short_codes", &PyCases::short_codes, py::keep_alive<0, 1>());
py::class_<PyCases::ShortCodeIter>(py_cases, "ShortCodeIter")
.def("__iter__", [](PyCases::ShortCodeIter &it) -> PyCases::ShortCodeIter& { return it; })
.def("__next__", &PyCases::ShortCodeIter::next);
py::class_<PyCases::CommonCodeIter>(py_cases, "CommonCodeIter")
.def("__iter__", [](PyCases::CommonCodeIter &it) -> PyCases::CommonCodeIter& { return it; })
.def("__next__", &PyCases::CommonCodeIter::next);
bind_short_code(m); bind_short_code(m);
bind_common_code(m); bind_common_code(m);
bind_cases(m);
py::class_<PyGroupUnion>(m, "GroupUnion") py::class_<PyGroupUnion>(m, "GroupUnion")
.def(py::init<uint8_t>()) .def(py::init<uint8_t>())
.def(py::init<PyShortCode>()) .def(py::init<PyShortCode>())

21
src/core_ffi/py_ffi/binder/cases.cc

@ -0,0 +1,21 @@
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "include/py_cases.h"
namespace py = pybind11;
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
void bind_cases(const py::module_ &m) {
py::class_<PyCases>(m, "Cases")
.def(py::self == py::self)
.def("__len__", &PyCases::size)
.def("__repr__", &PyCases::repr)
.def("__getitem__", &PyCases::at)
.def("__iter__", &PyCases::codes, py::keep_alive<0, 1>());
py::class_<PyCasesIter>(m, "CasesIter")
.def("__next__", &PyCasesIter::next);
}

57
src/core_ffi/py_ffi/include/py_cases.h

@ -11,35 +11,27 @@ namespace klotski::ffi {
using cases::RangesUnion; using cases::RangesUnion;
class PyCases { class PyCasesIter {
public: public:
PyCases() = delete; /// Construct from RangesUnion reference.
explicit PyCasesIter(const RangesUnion &data);
// ------------------------------------------------------------------------------------- //
class CommonCodeIter { /// Get the next CommonCode or throw a stop_iteration exception.
public: PyCommonCode next();
PyCommonCode next();
explicit CommonCodeIter(const RangesUnion &data);
private: private:
uint8_t head_ {0}; uint8_t head_ {0};
uint32_t index_ {0}; uint32_t index_ {0};
const RangesUnion &data_; const RangesUnion &data_;
}; };
class ShortCodeIter {
public:
PyShortCode next();
explicit ShortCodeIter(CommonCodeIter iter);
private: class PyCases {
CommonCodeIter iter_; public:
}; PyCases() = delete;
// ------------------------------------------------------------------------------------- // // ------------------------------------------------------------------------------------- //
/// Constructing from rvalue. /// Constructing from r-value.
static PyCases from(RangesUnion &&data) noexcept; static PyCases from(RangesUnion &&data) noexcept;
/// Constructing from longer-lived reference. /// Constructing from longer-lived reference.
@ -48,20 +40,23 @@ public:
// ------------------------------------------------------------------------------------- // // ------------------------------------------------------------------------------------- //
/// Get the number of cases. /// Get the number of cases.
[[nodiscard]] size_t size() const; [[nodiscard]] size_t size() const noexcept;
/// Get ShortCode iterator of cases.
[[nodiscard]] ShortCodeIter short_codes() const;
/// Get CommonCode iterator of cases. /// Get CommonCode iterator of cases.
[[nodiscard]] CommonCodeIter common_codes() const; [[nodiscard]] PyCasesIter codes() const noexcept;
/// Get the CommonCode of the specified index. /// Get the CommonCode of the specified index.
[[nodiscard]] PyCommonCode operator[](size_t index) const; [[nodiscard]] PyCommonCode at(size_t index) const; // TODO: allow `-1` index
// ------------------------------------------------------------------------------------- // // ------------------------------------------------------------------------------------- //
// TODO: add len / repr /// Wrapper of `__repr__` method in Python.
static std::string repr(const PyCases &cases) noexcept;
/// Compare the cases contents of two PyCases.
friend constexpr auto operator==(const PyCases &lhs, const PyCases &rhs);
// ------------------------------------------------------------------------------------- //
private: private:
explicit PyCases(RangesUnion &&data); explicit PyCases(RangesUnion &&data);
@ -78,6 +73,8 @@ private:
// ------------------------------------------------------------------------------------- // // ------------------------------------------------------------------------------------- //
}; };
// TODO: allow compare constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) {
return lhs.data_ref() == rhs.data_ref();
}
} // namespace klotski::ffi } // namespace klotski::ffi

34
src/core_ffi/py_ffi/wrapper/cases.cc

@ -1,3 +1,4 @@
#include <format>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include "include/py_cases.h" #include "include/py_cases.h"
@ -6,35 +7,25 @@ namespace py = pybind11;
using namespace klotski::ffi; using namespace klotski::ffi;
using ShortCodeIter = PyCases::ShortCodeIter;
using CommonCodeIter = PyCases::CommonCodeIter;
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
CommonCodeIter::CommonCodeIter(const RangesUnion &data) : data_(data) {} PyCasesIter::PyCasesIter(const RangesUnion &data) : data_(data) {}
ShortCodeIter::ShortCodeIter(PyCases::CommonCodeIter iter) : iter_(iter) {}
PyShortCode ShortCodeIter::next() { PyCommonCode PyCasesIter::next() {
return iter_.next().short_code();
}
PyCommonCode CommonCodeIter::next() {
while (head_ < 16) { while (head_ < 16) {
const auto &ranges = data_[head_]; const auto &ranges = data_[head_];
if (index_ < ranges.size()) { if (index_ < ranges.size()) {
auto code = (static_cast<uint64_t>(head_) << 32) | ranges[index_++]; auto code = (static_cast<uint64_t>(head_) << 32) | ranges[index_++];
return std::bit_cast<PyCommonCode>(code); return std::bit_cast<PyCommonCode>(code);
} }
++head_; index_ = 0, ++head_;
index_ = 0;
} }
throw py::stop_iteration(); throw py::stop_iteration();
} }
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
size_t PyCases::size() const { size_t PyCases::size() const noexcept {
size_t num = 0; size_t num = 0;
for (const auto &x : data_ref()) { // TODO: fetch from RangesUnion.size() for (const auto &x : data_ref()) { // TODO: fetch from RangesUnion.size()
num += x.size(); num += x.size();
@ -42,16 +33,11 @@ size_t PyCases::size() const {
return num; return num;
} }
auto PyCases::short_codes() const -> ShortCodeIter { PyCasesIter PyCases::codes() const noexcept {
return ShortCodeIter(common_codes()); return PyCasesIter(data_ref());
} }
auto PyCases::common_codes() const -> CommonCodeIter { PyCommonCode PyCases::at(size_t index) const {
return CommonCodeIter(data_ref());
}
PyCommonCode PyCases::operator[](size_t index) const {
if (index >= size()) { if (index >= size()) {
throw py::index_error("cases index out of range"); throw py::index_error("cases index out of range");
} }
@ -72,6 +58,10 @@ PyCommonCode PyCases::operator[](size_t index) const {
return std::bit_cast<PyCommonCode>(code); return std::bit_cast<PyCommonCode>(code);
} }
std::string PyCases::repr(const PyCases &cases) noexcept {
return std::format("<klotski.Cases size={}>", cases.size());
}
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
PyCases PyCases::from(RangesUnion &&data) noexcept { PyCases PyCases::from(RangesUnion &&data) noexcept {

Loading…
Cancel
Save