Browse Source

perf: operator overloading for python classes

master
Dnomd343 2 months ago
parent
commit
b406390638
  1. 61
      src/core_ffi/py_ffi/binder.cc
  2. 12
      src/core_ffi/py_ffi/codec/common_codec.cc
  3. 11
      src/core_ffi/py_ffi/codec/short_code.cc
  4. 66
      src/core_ffi/py_ffi/include/py_codec.h

61
src/core_ffi/py_ffi/binder.cc

@ -1,51 +1,74 @@
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/operators.h>
// #include <pybind11/stl.h>
// #include <pybind11/operators.h>
#include "py_exps.h" #include "py_exps.h"
#include "py_codec.h" #include "py_codec.h"
namespace py = pybind11; namespace py = pybind11;
using klotski::ffi::PyCodecExp;
using klotski::ffi::PyShortCode; using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode; using klotski::ffi::PyCommonCode;
using klotski::ffi::PyCodecExp; void bind_common_code(const py::module_ &m) {
py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>())
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
PYBIND11_MODULE(klotski, m) { .def(py::hash(py::self))
.def("__str__", &PyCommonCode::str)
.def("__int__", &PyCommonCode::value)
.def("__repr__", &PyCommonCode::repr)
py::register_exception<PyCodecExp>(m, "CodecExp", PyExc_ValueError); .def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint64_t())
.def(py::self < uint64_t()).def(py::self <= uint64_t())
.def(py::self > uint64_t()).def(py::self >= uint64_t())
.def_property_readonly("str", &PyCommonCode::string)
.def_property_readonly("value", &PyCommonCode::value)
.def_property_readonly("short_code", &PyCommonCode::short_code)
.def_static("check", static_cast<bool (*)(uint64_t)>(&PyCommonCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyCommonCode::check));
}
void bind_short_code(const py::module_ &m) {
py::class_<PyShortCode>(m, "ShortCode") py::class_<PyShortCode>(m, "ShortCode")
.def(py::init<uint32_t>()) .def(py::init<uint32_t>())
.def(py::init<PyCommonCode>()) .def(py::init<PyCommonCode>())
.def(py::init<std::string_view>()) .def(py::init<std::string_view>())
.def(py::hash(py::self))
.def("__str__", &PyShortCode::str) .def("__str__", &PyShortCode::str)
.def("__int__", &PyShortCode::value)
.def("__repr__", &PyShortCode::repr) .def("__repr__", &PyShortCode::repr)
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint32_t())
.def(py::self < uint32_t()).def(py::self <= uint32_t())
.def(py::self > uint32_t()).def(py::self >= uint32_t())
.def_property_readonly("value", &PyShortCode::value) .def_property_readonly("value", &PyShortCode::value)
.def_property_readonly("common_code", &PyShortCode::common_code) .def_property_readonly("common_code", &PyShortCode::common_code)
.def_static("check", static_cast<bool (*)(uint32_t)>(&PyShortCode::check)) .def_static("check", static_cast<bool (*)(uint32_t)>(&PyShortCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyShortCode::check)) .def_static("check", static_cast<bool (*)(std::string_view)>(&PyShortCode::check))
.def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false); .def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false);
}
py::class_<PyCommonCode>(m, "CommonCode") PYBIND11_MODULE(klotski, m) {
.def(py::init<uint64_t>()) py::register_exception<PyCodecExp>(m, "CodecExp", PyExc_ValueError);
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
.def("__str__", &PyCommonCode::str)
.def("__repr__", &PyCommonCode::repr)
.def_property_readonly("value", &PyCommonCode::value)
.def_property_readonly("short_code", &PyCommonCode::short_code)
.def("string", &PyCommonCode::string, py::arg("shorten") = false)
.def_static("check", static_cast<bool (*)(uint64_t)>(&PyCommonCode::check)) bind_short_code(m);
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyCommonCode::check)); bind_common_code(m);
m.attr("__version__") = "version field"; m.attr("__version__") = "version field";
} }

12
src/core_ffi/py_ffi/codec/common_codec.cc

@ -17,8 +17,8 @@ uint64_t PyCommonCode::value() const {
return code_.unwrap(); return code_.unwrap();
} }
std::string PyCommonCode::string(const bool shorten) const { std::string PyCommonCode::string() const {
return code_.to_string(shorten); return code_.to_string(true);
} }
PyShortCode PyCommonCode::short_code() const { PyShortCode PyCommonCode::short_code() const {
@ -32,14 +32,13 @@ bool PyCommonCode::check(const uint64_t code) {
} }
bool PyCommonCode::check(const std::string_view code) { bool PyCommonCode::check(const std::string_view code) {
// TODO: using `std::string_view` in from_string return CommonCode::from_string(code).has_value();
return CommonCode::from_string(std::string {code}).has_value();
} }
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
std::string PyCommonCode::str(const PyCommonCode code) { std::string PyCommonCode::str(const PyCommonCode code) {
return code.string(false); return code.code_.to_string();
} }
std::string PyCommonCode::repr(const PyCommonCode code) { std::string PyCommonCode::repr(const PyCommonCode code) {
@ -60,8 +59,7 @@ static CommonCode convert(const uint64_t code) {
} }
static CommonCode convert(const std::string_view code) { static CommonCode convert(const std::string_view code) {
// TODO: using `std::string_view` in from_string if (const auto str = CommonCode::from_string(code)) {
if (const auto str = CommonCode::from_string(std::string {code})) {
return str.value(); return str.value();
} }
throw PyCodecExp(std::format("invalid common code -> `{}`", code)); throw PyCodecExp(std::format("invalid common code -> `{}`", code));

11
src/core_ffi/py_ffi/codec/short_code.cc

@ -28,8 +28,7 @@ bool PyShortCode::check(const uint32_t code) {
} }
bool PyShortCode::check(const std::string_view code) { bool PyShortCode::check(const std::string_view code) {
// TODO: using `std::string_view` in from_string return ShortCode::from_string(code).has_value();
return ShortCode::from_string(std::string {code}).has_value();
} }
void PyShortCode::speed_up(const bool fast_mode) { void PyShortCode::speed_up(const bool fast_mode) {
@ -39,12 +38,11 @@ void PyShortCode::speed_up(const bool fast_mode) {
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
std::string PyShortCode::str(const PyShortCode code) { std::string PyShortCode::str(const PyShortCode code) {
return std::bit_cast<ShortCode>(code).to_string(); return code.code_.to_string();
} }
std::string PyShortCode::repr(const PyShortCode code) { std::string PyShortCode::repr(const PyShortCode code) {
const auto str = code.code_.to_string(); return std::format("<klotski.ShortCode {} @{}>", code.value(), str(code));
return std::format("<klotski.ShortCode {} @{}>", code.value(), str);
} }
// ----------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
@ -61,8 +59,7 @@ static ShortCode convert(const uint32_t code) {
} }
static ShortCode convert(const std::string_view code) { static ShortCode convert(const std::string_view code) {
// TODO: using `std::string_view` in from_string if (const auto str = ShortCode::from_string(code)) {
if (const auto str = ShortCode::from_string(std::string {code})) {
return str.value(); return str.value();
} }
throw PyCodecExp(std::format("invalid short code -> `{}`", code)); throw PyCodecExp(std::format("invalid short code -> `{}`", code));

66
src/core_ffi/py_ffi/include/py_codec.h

@ -10,7 +10,7 @@ using codec::CommonCode;
class PyCommonCode; class PyCommonCode;
// ------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
class PyShortCode { class PyShortCode {
public: public:
@ -43,7 +43,25 @@ private:
ShortCode code_; ShortCode code_;
}; };
// ------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
constexpr auto operator==(const PyShortCode &lhs, const uint32_t rhs) {
return lhs.value() == rhs;
}
constexpr auto operator<=>(const PyShortCode &lhs, const uint32_t rhs) {
return lhs.value() <=> rhs;
}
constexpr auto operator==(const PyShortCode &lhs, const PyShortCode &rhs) {
return lhs.value() == rhs.value();
}
constexpr auto operator<=>(const PyShortCode &lhs, const PyShortCode &rhs) {
return lhs.value() <=> rhs.value();
}
// ----------------------------------------------------------------------------------------- //
class PyCommonCode { class PyCommonCode {
public: public:
@ -54,12 +72,12 @@ public:
/// Get original value. /// Get original value.
[[nodiscard]] uint64_t value() const; [[nodiscard]] uint64_t value() const;
/// Convert as shorten string form.
[[nodiscard]] std::string string() const;
/// Convert CommonCode to ShortCode. /// Convert CommonCode to ShortCode.
[[nodiscard]] PyShortCode short_code() const; [[nodiscard]] PyShortCode short_code() const;
/// Convert as string form.
[[nodiscard]] std::string string(bool shorten) const;
/// Verify CommonCode in u64 form. /// Verify CommonCode in u64 form.
static bool check(uint64_t code); static bool check(uint64_t code);
@ -76,6 +94,42 @@ private:
CommonCode code_; CommonCode code_;
}; };
// ------------------------------------------------------------------------------------- // // ----------------------------------------------------------------------------------------- //
constexpr auto operator==(const PyCommonCode &lhs, const uint64_t rhs) {
return lhs.value() == rhs;
}
constexpr auto operator<=>(const PyCommonCode &lhs, const uint64_t rhs) {
return lhs.value() <=> rhs;
}
constexpr auto operator==(const PyCommonCode &lhs, const PyCommonCode &rhs) {
return lhs.value() == rhs.value();
}
constexpr auto operator<=>(const PyCommonCode &lhs, const PyCommonCode &rhs) {
return lhs.value() <=> rhs.value();
}
// ----------------------------------------------------------------------------------------- //
} // namespace klotski::ffi } // namespace klotski::ffi
// ----------------------------------------------------------------------------------------- //
template<>
struct std::hash<klotski::ffi::PyShortCode> {
size_t operator()(const klotski::ffi::PyShortCode &short_code) const noexcept {
return std::hash<uint32_t>()(short_code.value());
}
};
template<>
struct std::hash<klotski::ffi::PyCommonCode> {
size_t operator()(const klotski::ffi::PyCommonCode &common_code) const noexcept {
return std::hash<uint64_t>()(common_code.value());
}
};
// ----------------------------------------------------------------------------------------- //

Loading…
Cancel
Save