Browse Source

perf: operator overloading for python classes

master
Dnomd343 6 months ago
parent
commit
b406390638
  1. 61
      src/core_ffi/py_ffi/binder.cc
  2. 12
      src/core_ffi/py_ffi/codec/common_codec.cc
  3. 11
      src/core_ffi/py_ffi/codec/short_code.cc
  4. 66
      src/core_ffi/py_ffi/include/py_codec.h

61
src/core_ffi/py_ffi/binder.cc

@ -1,51 +1,74 @@
#include <pybind11/pybind11.h>
// #include <pybind11/stl.h>
// #include <pybind11/operators.h>
#include <pybind11/operators.h>
#include "py_exps.h"
#include "py_codec.h"
namespace py = pybind11;
using klotski::ffi::PyCodecExp;
using klotski::ffi::PyShortCode;
using klotski::ffi::PyCommonCode;
using klotski::ffi::PyCodecExp;
void bind_common_code(const py::module_ &m) {
py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>())
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
PYBIND11_MODULE(klotski, m) {
.def(py::hash(py::self))
.def("__str__", &PyCommonCode::str)
.def("__int__", &PyCommonCode::value)
.def("__repr__", &PyCommonCode::repr)
py::register_exception<PyCodecExp>(m, "CodecExp", PyExc_ValueError);
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint64_t())
.def(py::self < uint64_t()).def(py::self <= uint64_t())
.def(py::self > uint64_t()).def(py::self >= uint64_t())
.def_property_readonly("str", &PyCommonCode::string)
.def_property_readonly("value", &PyCommonCode::value)
.def_property_readonly("short_code", &PyCommonCode::short_code)
.def_static("check", static_cast<bool (*)(uint64_t)>(&PyCommonCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyCommonCode::check));
}
void bind_short_code(const py::module_ &m) {
py::class_<PyShortCode>(m, "ShortCode")
.def(py::init<uint32_t>())
.def(py::init<PyCommonCode>())
.def(py::init<std::string_view>())
.def(py::hash(py::self))
.def("__str__", &PyShortCode::str)
.def("__int__", &PyShortCode::value)
.def("__repr__", &PyShortCode::repr)
.def(py::self == py::self)
.def(py::self < py::self).def(py::self <= py::self)
.def(py::self > py::self).def(py::self >= py::self)
.def(py::self == uint32_t())
.def(py::self < uint32_t()).def(py::self <= uint32_t())
.def(py::self > uint32_t()).def(py::self >= uint32_t())
.def_property_readonly("value", &PyShortCode::value)
.def_property_readonly("common_code", &PyShortCode::common_code)
.def_static("check", static_cast<bool (*)(uint32_t)>(&PyShortCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyShortCode::check))
.def_static("speed_up", &PyShortCode::speed_up, py::arg("fast_mode") = false);
}
py::class_<PyCommonCode>(m, "CommonCode")
.def(py::init<uint64_t>())
.def(py::init<PyShortCode>())
.def(py::init<std::string_view>())
.def("__str__", &PyCommonCode::str)
.def("__repr__", &PyCommonCode::repr)
.def_property_readonly("value", &PyCommonCode::value)
.def_property_readonly("short_code", &PyCommonCode::short_code)
.def("string", &PyCommonCode::string, py::arg("shorten") = false)
PYBIND11_MODULE(klotski, m) {
py::register_exception<PyCodecExp>(m, "CodecExp", PyExc_ValueError);
.def_static("check", static_cast<bool (*)(uint64_t)>(&PyCommonCode::check))
.def_static("check", static_cast<bool (*)(std::string_view)>(&PyCommonCode::check));
bind_short_code(m);
bind_common_code(m);
m.attr("__version__") = "version field";
}

12
src/core_ffi/py_ffi/codec/common_codec.cc

@ -17,8 +17,8 @@ uint64_t PyCommonCode::value() const {
return code_.unwrap();
}
std::string PyCommonCode::string(const bool shorten) const {
return code_.to_string(shorten);
std::string PyCommonCode::string() const {
return code_.to_string(true);
}
PyShortCode PyCommonCode::short_code() const {
@ -32,14 +32,13 @@ bool PyCommonCode::check(const uint64_t code) {
}
bool PyCommonCode::check(const std::string_view code) {
// TODO: using `std::string_view` in from_string
return CommonCode::from_string(std::string {code}).has_value();
return CommonCode::from_string(code).has_value();
}
// ----------------------------------------------------------------------------------------- //
std::string PyCommonCode::str(const PyCommonCode code) {
return code.string(false);
return code.code_.to_string();
}
std::string PyCommonCode::repr(const PyCommonCode code) {
@ -60,8 +59,7 @@ static CommonCode convert(const uint64_t code) {
}
static CommonCode convert(const std::string_view code) {
// TODO: using `std::string_view` in from_string
if (const auto str = CommonCode::from_string(std::string {code})) {
if (const auto str = CommonCode::from_string(code)) {
return str.value();
}
throw PyCodecExp(std::format("invalid common code -> `{}`", code));

11
src/core_ffi/py_ffi/codec/short_code.cc

@ -28,8 +28,7 @@ bool PyShortCode::check(const uint32_t code) {
}
bool PyShortCode::check(const std::string_view code) {
// TODO: using `std::string_view` in from_string
return ShortCode::from_string(std::string {code}).has_value();
return ShortCode::from_string(code).has_value();
}
void PyShortCode::speed_up(const bool fast_mode) {
@ -39,12 +38,11 @@ void PyShortCode::speed_up(const bool fast_mode) {
// ----------------------------------------------------------------------------------------- //
std::string PyShortCode::str(const PyShortCode code) {
return std::bit_cast<ShortCode>(code).to_string();
return code.code_.to_string();
}
std::string PyShortCode::repr(const PyShortCode code) {
const auto str = code.code_.to_string();
return std::format("<klotski.ShortCode {} @{}>", code.value(), str);
return std::format("<klotski.ShortCode {} @{}>", code.value(), str(code));
}
// ----------------------------------------------------------------------------------------- //
@ -61,8 +59,7 @@ static ShortCode convert(const uint32_t code) {
}
static ShortCode convert(const std::string_view code) {
// TODO: using `std::string_view` in from_string
if (const auto str = ShortCode::from_string(std::string {code})) {
if (const auto str = ShortCode::from_string(code)) {
return str.value();
}
throw PyCodecExp(std::format("invalid short code -> `{}`", code));

66
src/core_ffi/py_ffi/include/py_codec.h

@ -10,7 +10,7 @@ using codec::CommonCode;
class PyCommonCode;
// ------------------------------------------------------------------------------------- //
// ----------------------------------------------------------------------------------------- //
class PyShortCode {
public:
@ -43,7 +43,25 @@ private:
ShortCode code_;
};
// ------------------------------------------------------------------------------------- //
// ----------------------------------------------------------------------------------------- //
constexpr auto operator==(const PyShortCode &lhs, const uint32_t rhs) {
return lhs.value() == rhs;
}
constexpr auto operator<=>(const PyShortCode &lhs, const uint32_t rhs) {
return lhs.value() <=> rhs;
}
constexpr auto operator==(const PyShortCode &lhs, const PyShortCode &rhs) {
return lhs.value() == rhs.value();
}
constexpr auto operator<=>(const PyShortCode &lhs, const PyShortCode &rhs) {
return lhs.value() <=> rhs.value();
}
// ----------------------------------------------------------------------------------------- //
class PyCommonCode {
public:
@ -54,12 +72,12 @@ public:
/// Get original value.
[[nodiscard]] uint64_t value() const;
/// Convert as shorten string form.
[[nodiscard]] std::string string() const;
/// Convert CommonCode to ShortCode.
[[nodiscard]] PyShortCode short_code() const;
/// Convert as string form.
[[nodiscard]] std::string string(bool shorten) const;
/// Verify CommonCode in u64 form.
static bool check(uint64_t code);
@ -76,6 +94,42 @@ private:
CommonCode code_;
};
// ------------------------------------------------------------------------------------- //
// ----------------------------------------------------------------------------------------- //
constexpr auto operator==(const PyCommonCode &lhs, const uint64_t rhs) {
return lhs.value() == rhs;
}
constexpr auto operator<=>(const PyCommonCode &lhs, const uint64_t rhs) {
return lhs.value() <=> rhs;
}
constexpr auto operator==(const PyCommonCode &lhs, const PyCommonCode &rhs) {
return lhs.value() == rhs.value();
}
constexpr auto operator<=>(const PyCommonCode &lhs, const PyCommonCode &rhs) {
return lhs.value() <=> rhs.value();
}
// ----------------------------------------------------------------------------------------- //
} // namespace klotski::ffi
// ----------------------------------------------------------------------------------------- //
template<>
struct std::hash<klotski::ffi::PyShortCode> {
size_t operator()(const klotski::ffi::PyShortCode &short_code) const noexcept {
return std::hash<uint32_t>()(short_code.value());
}
};
template<>
struct std::hash<klotski::ffi::PyCommonCode> {
size_t operator()(const klotski::ffi::PyCommonCode &common_code) const noexcept {
return std::hash<uint64_t>()(common_code.value());
}
};
// ----------------------------------------------------------------------------------------- //

Loading…
Cancel
Save