From 4d4523d0720bba5d44e50aa5eb7dee42ea9e35dd Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sun, 9 Jun 2024 14:43:04 +0800 Subject: [PATCH] update: several improvements for Group module --- src/core/benchmark/group.cc | 44 +++-- src/core/group/group.h | 48 +++--- .../group/internal/constant/group_union.h | 38 ++--- src/core/group/internal/group.inl | 5 + src/core/group/internal/group_union.cc | 161 ++++-------------- src/core/group/internal/group_union.inl | 13 +- src/core/main.cc | 13 ++ 7 files changed, 130 insertions(+), 192 deletions(-) create mode 100644 src/core/group/internal/group.inl diff --git a/src/core/benchmark/group.cc b/src/core/benchmark/group.cc index d99e022..066d08d 100644 --- a/src/core/benchmark/group.cc +++ b/src/core/benchmark/group.cc @@ -25,7 +25,7 @@ static std::vector all_common_codes() { return codes; } -std::vector common_code_samples(uint64_t num) { +std::vector common_code_samples(uint64_t num) { static auto codes = all_common_codes(); @@ -34,7 +34,7 @@ std::vector common_code_samples(uint64_t num) { // uint64_t offset = 0; uint64_t offset = part_size / 2; - std::vector result; + std::vector result; for (uint64_t i = 0; i < num; ++i) { uint64_t index = i * part_size + offset; @@ -43,22 +43,24 @@ std::vector common_code_samples(uint64_t num) { // uint64_t kk[] {343, 666, 114514, 35324, 123454, 76453, 93411}; // uint64_t index = kk[i % 7]; - result.emplace_back(codes[index]); + result.emplace_back(klotski::codec::CommonCode::unsafe_create(codes[index])); } return result; - } -std::vector raw_code_samples(uint64_t num) { +std::vector raw_code_samples(uint64_t num) { auto codes = common_code_samples(num); - for (auto &code : codes) { - code = klotski::codec::CommonCode::unsafe_create(code).to_raw_code().unwrap(); + std::vector raw_codes; + + raw_codes.reserve(codes.size()); + for (auto code : codes) { + raw_codes.emplace_back(code.to_raw_code()); } - return codes; + return raw_codes; } static void CommonCodeToTypeId(benchmark::State &state) { @@ -69,7 +71,7 @@ static void CommonCodeToTypeId(benchmark::State &state) { for (auto code : samples) { - // volatile auto ret = klotski::cases::common_code_to_type_id(code); + volatile auto ret = klotski::cases::GroupUnion::type_id(code); } } @@ -82,16 +84,16 @@ static void RawCodeToTypeId(benchmark::State &state) { auto samples = raw_code_samples(state.range(0)); - // for (auto code : samples) { - // if (klotski::codec::RawCode::check(code) == false) { - // std::cout << "error" << std::endl; - // } - // } + for (auto code : samples) { + if (klotski::codec::RawCode::check(code.code_) == false) { + std::cout << "error" << std::endl; + } + } for (auto _ : state) { for (auto code : samples) { - // volatile auto ret = klotski::cases::raw_code_to_type_id(code); + volatile auto ret = klotski::cases::GroupUnion::type_id(code); } } @@ -222,12 +224,22 @@ static void RangesDerive(benchmark::State &state) { auto group_union = klotski::cases::GroupUnion::unsafe_create(169); + std::vector unions; + unions.reserve(klotski::cases::TYPE_ID_LIMIT); + for (int type_id = 0; type_id < klotski::cases::TYPE_ID_LIMIT; ++type_id) { + unions.emplace_back(klotski::cases::GroupUnion::create(type_id).value()); + } + for (auto _ : state) { // results.clear(); // results.reserve(klotski::cases::ALL_CASES_NUM[5]); - volatile auto tmp = group_union.cases(); + // volatile auto tmp = group_union.cases(); + + for (auto g_union : unions) { + volatile auto tmp = g_union.cases(); + } } diff --git a/src/core/group/group.h b/src/core/group/group.h index 313f524..18e3f4a 100644 --- a/src/core/group/group.h +++ b/src/core/group/group.h @@ -75,11 +75,6 @@ namespace klotski::cases { constexpr uint32_t TYPE_ID_LIMIT = 203; constexpr uint32_t ALL_GROUP_NUM = 25422; -// uint32_t common_code_to_type_id(uint64_t common_code); -// uint32_t raw_code_to_type_id(uint64_t raw_code); - -// std::vector group_extend_from_seed(uint64_t raw_code); - class Group; // TODO: add constexpr @@ -90,28 +85,31 @@ public: // ------------------------------------------------------------------------------------- // /// Get the original type id. - [[nodiscard]] uint32_t unwrap() const; + [[nodiscard]] constexpr uint32_t unwrap() const; /// Create GroupUnion without any check. - static GroupUnion unsafe_create(uint32_t type_id); + static constexpr GroupUnion unsafe_create(uint32_t type_id); /// Create GroupUnion with validity check. - static std::optional create(uint32_t type_id); + static constexpr std::optional create(uint32_t type_id); // ------------------------------------------------------------------------------------- // /// Get the number of cases contained. - [[nodiscard]] uint32_t size() const; + [[nodiscard]] constexpr uint32_t size() const; /// Get the number of groups contained. - [[nodiscard]] uint32_t group_num() const; + [[nodiscard]] constexpr uint32_t group_num() const; /// Get the upper limit of the group size. - [[nodiscard]] uint32_t max_group_size() const; + [[nodiscard]] constexpr uint32_t max_group_size() const; + + // ------------------------------------------------------------------------------------- // + /// Get all cases under the current type id. [[nodiscard]] RangesUnion cases() const; - /// Get all group instances under the current type id. + /// Get all groups under the current type id. [[nodiscard]] std::vector groups() const; /// Get the group instance with the specified group id. @@ -131,15 +129,15 @@ public: // ------------------------------------------------------------------------------------- // private: - uint32_t type_id_ {}; + uint32_t type_id_; // ------------------------------------------------------------------------------------- // /// Get the type id of RawCode. - static uint32_t type_id(codec::RawCode raw_code); + static KLSK_INLINE uint32_t type_id(codec::RawCode raw_code); /// Get the type id of CommonCode. - static uint32_t type_id(codec::CommonCode common_code); + static KLSK_INLINE uint32_t type_id(codec::CommonCode common_code); // ------------------------------------------------------------------------------------- // }; @@ -151,25 +149,36 @@ class Group { public: Group() = delete; + // ------------------------------------------------------------------------------------- // + /// Create Group without any check. static Group unsafe_create(uint32_t type_id, uint32_t group_id); /// Create Group with validity check. static std::optional create(uint32_t type_id, uint32_t group_id); + // ------------------------------------------------------------------------------------- // + + /// Get the number of cases contained. [[nodiscard]] uint32_t size() const; + /// Get all cases under the current group. [[nodiscard]] RangesUnion cases() const; + // ------------------------------------------------------------------------------------- // + static Group from_raw_code(codec::RawCode raw_code); + + static Group from_short_code(codec::ShortCode short_code); + static Group from_common_code(codec::CommonCode common_code); -// private: - uint32_t type_id_; - uint32_t group_id_; + // ------------------------------------------------------------------------------------- // - [[nodiscard]] uint32_t flat_id() const; +private: + uint32_t flat_id_; + // TODO: maybe we can using `std::vector` static std::vector extend(codec::RawCode raw_code); }; @@ -204,3 +213,4 @@ private: } // namespace klotski::cases #include "internal/group_union.inl" +#include "internal/group.inl" diff --git a/src/core/group/internal/constant/group_union.h b/src/core/group/internal/constant/group_union.h index ca77e0c..2f6c29f 100644 --- a/src/core/group/internal/constant/group_union.h +++ b/src/core/group/internal/constant/group_union.h @@ -102,36 +102,36 @@ constexpr auto GROUP_UNION_CASES_NUM = std::to_array - -// #include -// #include -// #include - -#include -#include +using klotski::codec::RawCode; +using klotski::codec::CommonCode; +using klotski::cases::GroupUnion; -#include "constant/group_union.h" +#define RANGE_DERIVE(HEAD) ranges.derive(HEAD, cases[HEAD]) -static KLSK_INLINE uint32_t type_id(const int n, const int n_2x1, const int n_1x1) { +static KLSK_INLINE uint32_t to_type_id(const int n, const int n_2x1, const int n_1x1) { constexpr int offset[8] = {0, 15, 41, 74, 110, 145, 175, 196}; return offset[n] + (15 - n * 2) * n_2x1 + n_1x1; } -static uint32_t common_code_to_type_id(const uint64_t common_code) { - const auto range = static_cast(common_code); +uint32_t GroupUnion::type_id(const CommonCode common_code) { + const auto range = static_cast(common_code.unwrap()); const auto n_1x1 = std::popcount((range >> 1) & range & 0x55555555); const auto n_2x1 = std::popcount((range >> 1) & ~range & 0x55555555); - return type_id(std::popcount(range) - n_1x1 * 2, n_2x1, n_1x1); -} - -static uint32_t raw_code_to_type_id(const uint64_t raw_code) { - const auto n = std::popcount(((raw_code >> 1) ^ raw_code) & 0x0249249249249249); - const auto n_2x1 = std::popcount((raw_code >> 1) & ~raw_code & 0x0249249249249249); - const auto n_1x1 = std::popcount((raw_code >> 1) & raw_code & 0x0249249249249249) - n - 3; - return type_id(n, n_2x1, n_1x1); -} - -uint32_t klotski::cases::GroupUnion::type_id(codec::CommonCode common_code) { - return common_code_to_type_id(common_code.unwrap()); + return to_type_id(std::popcount(range) - n_1x1 * 2, n_2x1, n_1x1); } -uint32_t klotski::cases::GroupUnion::type_id(codec::RawCode raw_code) { - return raw_code_to_type_id(raw_code.unwrap()); +uint32_t GroupUnion::type_id(const RawCode raw_code) { + const auto code = raw_code.unwrap(); + const auto n = std::popcount(((code >> 1) ^ code) & 0x0249249249249249); + const auto n_2x1 = std::popcount((code >> 1) & ~code & 0x0249249249249249); + const auto n_1x1 = std::popcount((code >> 1) & code & 0x0249249249249249) - n - 3; + return to_type_id(n, n_2x1, n_1x1); } -klotski::cases::RangesUnion klotski::cases::GroupUnion::cases() const { - Ranges ranges {}; - +klotski::cases::RangesUnion GroupUnion::cases() const { auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id_]; + auto [s_a, s_b, s_c, s_d] = GROUP_UNION_CASES_NUM[type_id_]; - // int n = TYPE_ID_N_NUM[type_id_]; - // int n_2x1 = TYPE_ID_N_2x1_NUM[type_id_]; - // int n_1x1 = TYPE_ID_N_1x1_NUM[type_id_]; // TODO: cal from type_id + Ranges ranges {}; + ranges.reserve(BASIC_RANGES_NUM[type_id_]); ranges.spawn(n, n_2x1, n_1x1); - - // for (int i = 0; i < TYPE_ID_LIMIT; ++i) { - // ranges.spawn(TYPE_ID_N_NUM[i], TYPE_ID_N_2x1_NUM[i], TYPE_ID_N_1x1_NUM[i]); - // } - // std::stable_sort(ranges.begin(), ranges.end()); - - // for (auto &x : ranges) { - // x = klotski::range_reverse(x); - // } - ranges.reverse(); - // auto do_assert = [](uint32_t lhs, uint32_t rhs) { - // if (lhs != rhs) { - // std::cout << "error" << std::endl; - // } - // }; - RangesUnion cases; - - // cases[0x0].reserve(7815); - // cases[0x1].reserve(6795); - // cases[0x2].reserve(7815); - // - // cases[0x4].reserve(3525); - // cases[0x5].reserve(3465); - // cases[0x6].reserve(3525); - // - // cases[0x8].reserve(3525); - // cases[0x9].reserve(3465); - // cases[0xA].reserve(3525); - // - // cases[0xC].reserve(7815); - // cases[0xD].reserve(6795); - // cases[0xE].reserve(7815); - - auto [A, B, C, D] = GROUP_UNION_CASES_NUM[type_id_]; - - cases[0x0].reserve(A); - cases[0x1].reserve(B); - cases[0x2].reserve(A); - - cases[0x4].reserve(C); - cases[0x5].reserve(D); - cases[0x6].reserve(C); - - cases[0x8].reserve(C); - cases[0x9].reserve(D); - cases[0xA].reserve(C); - - cases[0xC].reserve(A); - cases[0xD].reserve(B); - cases[0xE].reserve(A); - - ranges.derive(0x0, cases[0x0]); - ranges.derive(0x1, cases[0x1]); - ranges.derive(0x2, cases[0x2]); - - ranges.derive(0x4, cases[0x4]); - ranges.derive(0x5, cases[0x5]); - ranges.derive(0x6, cases[0x6]); - - ranges.derive(0x8, cases[0x8]); - ranges.derive(0x9, cases[0x9]); - ranges.derive(0xA, cases[0xA]); - - ranges.derive(0xC, cases[0xC]); - ranges.derive(0xD, cases[0xD]); - ranges.derive(0xE, cases[0xE]); - - // uint32_t A = cases[0x0].size(); - // uint32_t B = cases[0x1].size(); - // uint32_t C = cases[0x4].size(); - // uint32_t D = cases[0x5].size(); - // - // do_assert(cases[0x2].size(), A); - // do_assert(cases[0x6].size(), C); - // do_assert(cases[0x8].size(), C); - // do_assert(cases[0x9].size(), D); - // do_assert(cases[0xA].size(), C); - // do_assert(cases[0xC].size(), A); - // do_assert(cases[0xD].size(), B); - // do_assert(cases[0xE].size(), A); - // - // std::cout << A << ", " << B << ", " << C << ", " << D << std::endl; - - // auto [A, B, C, D] = kk[type_id_]; - // do_assert(cases[0x0].size(), A); - // do_assert(cases[0x1].size(), B); - // do_assert(cases[0x2].size(), A); - // - // do_assert(cases[0x4].size(), C); - // do_assert(cases[0x5].size(), D); - // do_assert(cases[0x6].size(), C); - // - // do_assert(cases[0x8].size(), C); - // do_assert(cases[0x9].size(), D); - // do_assert(cases[0xA].size(), C); - // - // do_assert(cases[0xC].size(), A); - // do_assert(cases[0xD].size(), B); - // do_assert(cases[0xE].size(), A); - + cases[0x0].reserve(s_a); cases[0x1].reserve(s_b); cases[0x2].reserve(s_a); + cases[0x4].reserve(s_c); cases[0x5].reserve(s_d); cases[0x6].reserve(s_c); + cases[0x8].reserve(s_c); cases[0x9].reserve(s_d); cases[0xA].reserve(s_c); + cases[0xC].reserve(s_a); cases[0xD].reserve(s_b); cases[0xE].reserve(s_a); + + RANGE_DERIVE(0x0); RANGE_DERIVE(0x1); RANGE_DERIVE(0x2); + RANGE_DERIVE(0x4); RANGE_DERIVE(0x5); RANGE_DERIVE(0x6); + RANGE_DERIVE(0x8); RANGE_DERIVE(0x9); RANGE_DERIVE(0xA); + RANGE_DERIVE(0xC); RANGE_DERIVE(0xD); RANGE_DERIVE(0xE); return cases; } diff --git a/src/core/group/internal/group_union.inl b/src/core/group/internal/group_union.inl index 23e0dff..bc8e007 100644 --- a/src/core/group/internal/group_union.inl +++ b/src/core/group/internal/group_union.inl @@ -6,15 +6,15 @@ namespace klotski::cases { // ------------------------------------------------------------------------------------- // -inline uint32_t GroupUnion::unwrap() const { +inline constexpr uint32_t GroupUnion::unwrap() const { return type_id_; } -inline GroupUnion GroupUnion::unsafe_create(const uint32_t type_id) { +inline constexpr GroupUnion GroupUnion::unsafe_create(const uint32_t type_id) { return std::bit_cast(type_id); } -inline std::optional GroupUnion::create(const uint32_t type_id) { +inline constexpr std::optional GroupUnion::create(const uint32_t type_id) { if (type_id < TYPE_ID_LIMIT) { return unsafe_create(type_id); } @@ -23,19 +23,20 @@ inline std::optional GroupUnion::create(const uint32_t type_id) { // ------------------------------------------------------------------------------------- // -inline uint32_t GroupUnion::size() const { +inline constexpr uint32_t GroupUnion::size() const { return GROUP_UNION_SIZE[type_id_]; } -inline uint32_t GroupUnion::group_num() const { +inline constexpr uint32_t GroupUnion::group_num() const { return GROUP_NUM[type_id_]; } -inline uint32_t GroupUnion::max_group_size() const { +inline constexpr uint32_t GroupUnion::max_group_size() const { return MAX_GROUP_SIZE[type_id_]; } inline std::vector GroupUnion::groups() const { + // TODO: using `std::iota` helper std::vector groups; diff --git a/src/core/main.cc b/src/core/main.cc index cb74819..c5ef81b 100644 --- a/src/core/main.cc +++ b/src/core/main.cc @@ -41,6 +41,19 @@ int main() { // std::cout << ret[4].size() << std::endl; } + // auto group_union = GroupUnion::unsafe_create(169); + // for (auto group : group_union.groups()) { + // std::cout << group.type_id_ << ", " << group.group_id_ << std::endl; + // } + + constexpr auto gu = GroupUnion::unsafe_create(169); + constexpr auto gu_ = GroupUnion::create(169).value(); + // constexpr auto gu_ = GroupUnion::create(1169).value(); + constexpr auto k1 = gu.unwrap(); + constexpr auto k2 = gu.size(); + constexpr auto k3 = gu.group_num(); + constexpr auto k4 = gu.max_group_size(); + std::cerr << std::chrono::system_clock::now() - start << std::endl; // auto raw_code = RawCode::from_common_code(0x1A9BF0C00)->unwrap();