Browse Source

update: enhance `PyCases` interfaces

master
Dnomd343 4 weeks ago
parent
commit
a8d03a56f8
  1. 1
      src/core_ffi/CMakeLists.txt
  2. 17
      src/core_ffi/py_ffi/binder.cc
  3. 21
      src/core_ffi/py_ffi/binder/cases.cc
  4. 57
      src/core_ffi/py_ffi/include/py_cases.h
  5. 34
      src/core_ffi/py_ffi/wrapper/cases.cc

1
src/core_ffi/CMakeLists.txt

@ -18,6 +18,7 @@ if (KLSK_PYTHON_FFI)
py_ffi/wrapper/cases.cc
py_ffi/wrapper/py_group_union.cc
py_ffi/wrapper/py_group.cc
py_ffi/binder/cases.cc
)
target_include_directories(klotski_py PRIVATE py_ffi/include)
target_include_directories(klotski_py PRIVATE py_ffi)

17
src/core_ffi/py_ffi/binder.cc

@ -24,6 +24,8 @@ using klotski::ffi::PyExc_CodecError;
using klotski::ffi::PyGroup;
using klotski::ffi::PyGroupUnion;
extern void bind_cases(const py::module_ &m);
void bind_common_code(const py::module_ &m) {
py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>())
@ -100,22 +102,11 @@ PYBIND11_MODULE(klotski, m) {
m.def("all_cases", &all_cases);
m.def("group_demo", &group_demo);
auto py_cases = py::class_<PyCases>(m, "Cases")
.def("size", &PyCases::size)
.def("__iter__", &PyCases::common_codes, py::keep_alive<0, 1>())
.def("short_codes", &PyCases::short_codes, py::keep_alive<0, 1>());
py::class_<PyCases::ShortCodeIter>(py_cases, "ShortCodeIter")
.def("__iter__", [](PyCases::ShortCodeIter &it) -> PyCases::ShortCodeIter& { return it; })
.def("__next__", &PyCases::ShortCodeIter::next);
py::class_<PyCases::CommonCodeIter>(py_cases, "CommonCodeIter")
.def("__iter__", [](PyCases::CommonCodeIter &it) -> PyCases::CommonCodeIter& { return it; })
.def("__next__", &PyCases::CommonCodeIter::next);
bind_short_code(m);
bind_common_code(m);
bind_cases(m);
py::class_<PyGroupUnion>(m, "GroupUnion")
.def(py::init<uint8_t>())
.def(py::init<PyShortCode>())

21
src/core_ffi/py_ffi/binder/cases.cc

@ -0,0 +1,21 @@
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "include/py_cases.h"
namespace py = pybind11;
using klotski::ffi::PyCases;
using klotski::ffi::PyCasesIter;
void bind_cases(const py::module_ &m) {
py::class_<PyCases>(m, "Cases")
.def(py::self == py::self)
.def("__len__", &PyCases::size)
.def("__repr__", &PyCases::repr)
.def("__getitem__", &PyCases::at)
.def("__iter__", &PyCases::codes, py::keep_alive<0, 1>());
py::class_<PyCasesIter>(m, "CasesIter")
.def("__next__", &PyCasesIter::next);
}

57
src/core_ffi/py_ffi/include/py_cases.h

@ -11,35 +11,27 @@ namespace klotski::ffi {
using cases::RangesUnion;
class PyCases {
class PyCasesIter {
public:
PyCases() = delete;
// ------------------------------------------------------------------------------------- //
/// Construct from RangesUnion reference.
explicit PyCasesIter(const RangesUnion &data);
class CommonCodeIter {
public:
PyCommonCode next();
explicit CommonCodeIter(const RangesUnion &data);
/// Get the next CommonCode or throw a stop_iteration exception.
PyCommonCode next();
private:
uint8_t head_ {0};
uint32_t index_ {0};
const RangesUnion &data_;
};
class ShortCodeIter {
public:
PyShortCode next();
explicit ShortCodeIter(CommonCodeIter iter);
private:
uint8_t head_ {0};
uint32_t index_ {0};
const RangesUnion &data_;
};
private:
CommonCodeIter iter_;
};
class PyCases {
public:
PyCases() = delete;
// ------------------------------------------------------------------------------------- //
/// Constructing from rvalue.
/// Constructing from r-value.
static PyCases from(RangesUnion &&data) noexcept;
/// Constructing from longer-lived reference.
@ -48,20 +40,23 @@ public:
// ------------------------------------------------------------------------------------- //
/// Get the number of cases.
[[nodiscard]] size_t size() const;
/// Get ShortCode iterator of cases.
[[nodiscard]] ShortCodeIter short_codes() const;
[[nodiscard]] size_t size() const noexcept;
/// Get CommonCode iterator of cases.
[[nodiscard]] CommonCodeIter common_codes() const;
[[nodiscard]] PyCasesIter codes() const noexcept;
/// Get the CommonCode of the specified index.
[[nodiscard]] PyCommonCode operator[](size_t index) const;
[[nodiscard]] PyCommonCode at(size_t index) const; // TODO: allow `-1` index
// ------------------------------------------------------------------------------------- //
// TODO: add len / repr
/// Wrapper of `__repr__` method in Python.
static std::string repr(const PyCases &cases) noexcept;
/// Compare the cases contents of two PyCases.
friend constexpr auto operator==(const PyCases &lhs, const PyCases &rhs);
// ------------------------------------------------------------------------------------- //
private:
explicit PyCases(RangesUnion &&data);
@ -78,6 +73,8 @@ private:
// ------------------------------------------------------------------------------------- //
};
// TODO: allow compare
constexpr auto operator==(const PyCases &lhs, const PyCases &rhs) {
return lhs.data_ref() == rhs.data_ref();
}
} // namespace klotski::ffi

34
src/core_ffi/py_ffi/wrapper/cases.cc

@ -1,3 +1,4 @@
#include <format>
#include <pybind11/pybind11.h>
#include "include/py_cases.h"
@ -6,35 +7,25 @@ namespace py = pybind11;
using namespace klotski::ffi;
using ShortCodeIter = PyCases::ShortCodeIter;
using CommonCodeIter = PyCases::CommonCodeIter;
// ----------------------------------------------------------------------------------------- //
CommonCodeIter::CommonCodeIter(const RangesUnion &data) : data_(data) {}
ShortCodeIter::ShortCodeIter(PyCases::CommonCodeIter iter) : iter_(iter) {}
PyCasesIter::PyCasesIter(const RangesUnion &data) : data_(data) {}
PyShortCode ShortCodeIter::next() {
return iter_.next().short_code();
}
PyCommonCode CommonCodeIter::next() {
PyCommonCode PyCasesIter::next() {
while (head_ < 16) {
const auto &ranges = data_[head_];
if (index_ < ranges.size()) {
auto code = (static_cast<uint64_t>(head_) << 32) | ranges[index_++];
return std::bit_cast<PyCommonCode>(code);
}
++head_;
index_ = 0;
index_ = 0, ++head_;
}
throw py::stop_iteration();
}
// ----------------------------------------------------------------------------------------- //
size_t PyCases::size() const {
size_t PyCases::size() const noexcept {
size_t num = 0;
for (const auto &x : data_ref()) { // TODO: fetch from RangesUnion.size()
num += x.size();
@ -42,16 +33,11 @@ size_t PyCases::size() const {
return num;
}
auto PyCases::short_codes() const -> ShortCodeIter {
return ShortCodeIter(common_codes());
PyCasesIter PyCases::codes() const noexcept {
return PyCasesIter(data_ref());
}
auto PyCases::common_codes() const -> CommonCodeIter {
return CommonCodeIter(data_ref());
}
PyCommonCode PyCases::operator[](size_t index) const {
PyCommonCode PyCases::at(size_t index) const {
if (index >= size()) {
throw py::index_error("cases index out of range");
}
@ -72,6 +58,10 @@ PyCommonCode PyCases::operator[](size_t index) const {
return std::bit_cast<PyCommonCode>(code);
}
std::string PyCases::repr(const PyCases &cases) noexcept {
return std::format("<klotski.Cases size={}>", cases.size());
}
// ----------------------------------------------------------------------------------------- //
PyCases PyCases::from(RangesUnion &&data) noexcept {

Loading…
Cancel
Save