From b406390638a7d21337b8f38905512fcc0a2a9506 Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sat, 18 May 2024 23:52:30 +0800 Subject: [PATCH] perf: operator overloading for python classes --- src/core_ffi/py_ffi/binder.cc | 61 ++++++++++++++------- src/core_ffi/py_ffi/codec/common_codec.cc | 12 ++--- src/core_ffi/py_ffi/codec/short_code.cc | 11 ++-- src/core_ffi/py_ffi/include/py_codec.h | 66 ++++++++++++++++++++--- 4 files changed, 111 insertions(+), 39 deletions(-) diff --git a/src/core_ffi/py_ffi/binder.cc b/src/core_ffi/py_ffi/binder.cc index dcfaea4..79492da 100644 --- a/src/core_ffi/py_ffi/binder.cc +++ b/src/core_ffi/py_ffi/binder.cc @@ -1,51 +1,74 @@ #include - -// #include -// #include +#include #include "py_exps.h" #include "py_codec.h" namespace py = pybind11; +using klotski::ffi::PyCodecExp; using klotski::ffi::PyShortCode; using klotski::ffi::PyCommonCode; -using klotski::ffi::PyCodecExp; +void bind_common_code(const py::module_ &m) { + py::class_(m, "CommonCode") + .def(py::init()) + .def(py::init()) + .def(py::init()) -PYBIND11_MODULE(klotski, m) { + .def(py::hash(py::self)) + .def("__str__", &PyCommonCode::str) + .def("__int__", &PyCommonCode::value) + .def("__repr__", &PyCommonCode::repr) - py::register_exception(m, "CodecExp", PyExc_ValueError); + .def(py::self == py::self) + .def(py::self < py::self).def(py::self <= py::self) + .def(py::self > py::self).def(py::self >= py::self) + + .def(py::self == uint64_t()) + .def(py::self < uint64_t()).def(py::self <= uint64_t()) + .def(py::self > uint64_t()).def(py::self >= uint64_t()) + + .def_property_readonly("str", &PyCommonCode::string) + .def_property_readonly("value", &PyCommonCode::value) + .def_property_readonly("short_code", &PyCommonCode::short_code) + + .def_static("check", static_cast(&PyCommonCode::check)) + .def_static("check", static_cast(&PyCommonCode::check)); +} +void bind_short_code(const py::module_ &m) { py::class_(m, "ShortCode") .def(py::init()) .def(py::init()) .def(py::init()) + .def(py::hash(py::self)) .def("__str__", &PyShortCode::str) + .def("__int__", &PyShortCode::value) .def("__repr__", &PyShortCode::repr) + .def(py::self == py::self) + .def(py::self < py::self).def(py::self <= py::self) + .def(py::self > py::self).def(py::self >= py::self) + + .def(py::self == uint32_t()) + .def(py::self < uint32_t()).def(py::self <= uint32_t()) + .def(py::self > uint32_t()).def(py::self >= uint32_t()) + .def_property_readonly("value", &PyShortCode::value) .def_property_readonly("common_code", &PyShortCode::common_code) .def_static("check", static_cast(&PyShortCode::check)) .def_static("check", static_cast(&PyShortCode::check)) .def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false); +} - py::class_(m, "CommonCode") - .def(py::init()) - .def(py::init()) - .def(py::init()) - - .def("__str__", &PyCommonCode::str) - .def("__repr__", &PyCommonCode::repr) - - .def_property_readonly("value", &PyCommonCode::value) - .def_property_readonly("short_code", &PyCommonCode::short_code) - .def("string", &PyCommonCode::string, py::arg("shorten") = false) +PYBIND11_MODULE(klotski, m) { + py::register_exception(m, "CodecExp", PyExc_ValueError); - .def_static("check", static_cast(&PyCommonCode::check)) - .def_static("check", static_cast(&PyCommonCode::check)); + bind_short_code(m); + bind_common_code(m); m.attr("__version__") = "version field"; } diff --git a/src/core_ffi/py_ffi/codec/common_codec.cc b/src/core_ffi/py_ffi/codec/common_codec.cc index 1e533b2..d474fe0 100644 --- a/src/core_ffi/py_ffi/codec/common_codec.cc +++ b/src/core_ffi/py_ffi/codec/common_codec.cc @@ -17,8 +17,8 @@ uint64_t PyCommonCode::value() const { return code_.unwrap(); } -std::string PyCommonCode::string(const bool shorten) const { - return code_.to_string(shorten); +std::string PyCommonCode::string() const { + return code_.to_string(true); } PyShortCode PyCommonCode::short_code() const { @@ -32,14 +32,13 @@ bool PyCommonCode::check(const uint64_t code) { } bool PyCommonCode::check(const std::string_view code) { - // TODO: using `std::string_view` in from_string - return CommonCode::from_string(std::string {code}).has_value(); + return CommonCode::from_string(code).has_value(); } // ----------------------------------------------------------------------------------------- // std::string PyCommonCode::str(const PyCommonCode code) { - return code.string(false); + return code.code_.to_string(); } std::string PyCommonCode::repr(const PyCommonCode code) { @@ -60,8 +59,7 @@ static CommonCode convert(const uint64_t code) { } static CommonCode convert(const std::string_view code) { - // TODO: using `std::string_view` in from_string - if (const auto str = CommonCode::from_string(std::string {code})) { + if (const auto str = CommonCode::from_string(code)) { return str.value(); } throw PyCodecExp(std::format("invalid common code -> `{}`", code)); diff --git a/src/core_ffi/py_ffi/codec/short_code.cc b/src/core_ffi/py_ffi/codec/short_code.cc index 672b5df..6bb6f49 100644 --- a/src/core_ffi/py_ffi/codec/short_code.cc +++ b/src/core_ffi/py_ffi/codec/short_code.cc @@ -28,8 +28,7 @@ bool PyShortCode::check(const uint32_t code) { } bool PyShortCode::check(const std::string_view code) { - // TODO: using `std::string_view` in from_string - return ShortCode::from_string(std::string {code}).has_value(); + return ShortCode::from_string(code).has_value(); } void PyShortCode::speed_up(const bool fast_mode) { @@ -39,12 +38,11 @@ void PyShortCode::speed_up(const bool fast_mode) { // ----------------------------------------------------------------------------------------- // std::string PyShortCode::str(const PyShortCode code) { - return std::bit_cast(code).to_string(); + return code.code_.to_string(); } std::string PyShortCode::repr(const PyShortCode code) { - const auto str = code.code_.to_string(); - return std::format("", code.value(), str); + return std::format("", code.value(), str(code)); } // ----------------------------------------------------------------------------------------- // @@ -61,8 +59,7 @@ static ShortCode convert(const uint32_t code) { } static ShortCode convert(const std::string_view code) { - // TODO: using `std::string_view` in from_string - if (const auto str = ShortCode::from_string(std::string {code})) { + if (const auto str = ShortCode::from_string(code)) { return str.value(); } throw PyCodecExp(std::format("invalid short code -> `{}`", code)); diff --git a/src/core_ffi/py_ffi/include/py_codec.h b/src/core_ffi/py_ffi/include/py_codec.h index fcf8755..278087c 100644 --- a/src/core_ffi/py_ffi/include/py_codec.h +++ b/src/core_ffi/py_ffi/include/py_codec.h @@ -10,7 +10,7 @@ using codec::CommonCode; class PyCommonCode; -// ------------------------------------------------------------------------------------- // +// ----------------------------------------------------------------------------------------- // class PyShortCode { public: @@ -43,7 +43,25 @@ private: ShortCode code_; }; -// ------------------------------------------------------------------------------------- // +// ----------------------------------------------------------------------------------------- // + +constexpr auto operator==(const PyShortCode &lhs, const uint32_t rhs) { + return lhs.value() == rhs; +} + +constexpr auto operator<=>(const PyShortCode &lhs, const uint32_t rhs) { + return lhs.value() <=> rhs; +} + +constexpr auto operator==(const PyShortCode &lhs, const PyShortCode &rhs) { + return lhs.value() == rhs.value(); +} + +constexpr auto operator<=>(const PyShortCode &lhs, const PyShortCode &rhs) { + return lhs.value() <=> rhs.value(); +} + +// ----------------------------------------------------------------------------------------- // class PyCommonCode { public: @@ -54,12 +72,12 @@ public: /// Get original value. [[nodiscard]] uint64_t value() const; + /// Convert as shorten string form. + [[nodiscard]] std::string string() const; + /// Convert CommonCode to ShortCode. [[nodiscard]] PyShortCode short_code() const; - /// Convert as string form. - [[nodiscard]] std::string string(bool shorten) const; - /// Verify CommonCode in u64 form. static bool check(uint64_t code); @@ -76,6 +94,42 @@ private: CommonCode code_; }; -// ------------------------------------------------------------------------------------- // +// ----------------------------------------------------------------------------------------- // + +constexpr auto operator==(const PyCommonCode &lhs, const uint64_t rhs) { + return lhs.value() == rhs; +} + +constexpr auto operator<=>(const PyCommonCode &lhs, const uint64_t rhs) { + return lhs.value() <=> rhs; +} + +constexpr auto operator==(const PyCommonCode &lhs, const PyCommonCode &rhs) { + return lhs.value() == rhs.value(); +} + +constexpr auto operator<=>(const PyCommonCode &lhs, const PyCommonCode &rhs) { + return lhs.value() <=> rhs.value(); +} + +// ----------------------------------------------------------------------------------------- // } // namespace klotski::ffi + +// ----------------------------------------------------------------------------------------- // + +template<> +struct std::hash { + size_t operator()(const klotski::ffi::PyShortCode &short_code) const noexcept { + return std::hash()(short_code.value()); + } +}; + +template<> +struct std::hash { + size_t operator()(const klotski::ffi::PyCommonCode &common_code) const noexcept { + return std::hash()(common_code.value()); + } +}; + +// ----------------------------------------------------------------------------------------- //