From e65109288df08be68f174f1d08008d640e3632a9 Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sat, 19 Oct 2024 16:05:38 +0800 Subject: [PATCH] update: enhance group union and group interfaces --- src/core/group/group.h | 62 +++++++++++++------ src/core/group/internal/group_cases_pro.cc | 4 +- src/core/group/internal/group_union.inl | 35 +++++++++-- src/core_test/CMakeLists.txt | 2 +- .../cases/{group.cc => group_legacy.cc} | 14 ++--- src/core_test/cases/group_pro.cc | 22 +++---- src/core_test/cases/group_union.cc | 8 +-- 7 files changed, 99 insertions(+), 48 deletions(-) rename src/core_test/cases/{group.cc => group_legacy.cc} (79%) diff --git a/src/core/group/group.h b/src/core/group/group.h index 5435035..65c7d0e 100644 --- a/src/core/group/group.h +++ b/src/core/group/group.h @@ -81,11 +81,12 @@ typedef std::vector CommonCodes; class Group; -// TODO: add constexpr class GroupUnion { public: GroupUnion() = delete; + using Groups = std::vector; + // ------------------------------------------------------------------------------------- // /// Get the original type id. @@ -105,17 +106,11 @@ public: /// Get the number of groups contained. [[nodiscard]] constexpr uint32_t group_num() const; - /// Get the upper limit of the group size. - [[nodiscard]] constexpr uint32_t max_group_size() const; - - // ------------------------------------------------------------------------------------- // - /// TODO: new interface - + /// Get the number of patterns contained. [[nodiscard]] constexpr uint32_t pattern_num() const; - [[nodiscard]] std::vector groups_pro() const; - - // TODO: get target pattern_id + /// Get the upper limit of the group size. + [[nodiscard]] constexpr uint32_t max_group_size() const; // ------------------------------------------------------------------------------------- // @@ -123,21 +118,26 @@ public: [[nodiscard]] RangesUnion cases() const; /// Get all groups under the current type id. -// [[nodiscard]] std::vector groups() const; + [[nodiscard]] constexpr Groups groups() const; - /// Get the group instance with the specified group id. -// [[nodiscard]] std::optional group(uint32_t group_id) const; + /// Get the group instance with the specified pattern id. + [[nodiscard]] constexpr std::optional groups(uint32_t pattern_id) const; // ------------------------------------------------------------------------------------- // /// Create GroupUnion from RawCode. - static GroupUnion from_raw_code(codec::RawCode raw_code); + static constexpr GroupUnion from_raw_code(codec::RawCode raw_code); /// Create GroupUnion from ShortCode. - static GroupUnion from_short_code(codec::ShortCode short_code); + static constexpr GroupUnion from_short_code(codec::ShortCode short_code); /// Create GroupUnion from CommonCode. - static GroupUnion from_common_code(codec::CommonCode common_code); + static constexpr GroupUnion from_common_code(codec::CommonCode common_code); + + // ------------------------------------------------------------------------------------- // + + /// Compare the type_id values of two GroupUnion. + friend constexpr auto operator==(const GroupUnion &lhs, const GroupUnion &rhs); // ------------------------------------------------------------------------------------- // @@ -155,8 +155,6 @@ private: // ------------------------------------------------------------------------------------- // }; -// TODO: add `==` and `std::hash` - // TODO: add debug output class Group { @@ -239,6 +237,11 @@ public: // ------------------------------------------------------------------------------------- // + /// Compare the internal values of two Group. + friend constexpr auto operator==(const Group &lhs, const Group &rhs); + + // ------------------------------------------------------------------------------------- // + private: uint32_t type_id_; Toward toward_; @@ -351,3 +354,26 @@ public: #include "internal/group_union.inl" #include "internal/group_cases.inl" #include "internal/group.inl" + +// ----------------------------------------------------------------------------------------- // + +namespace std { + +template <> +struct std::hash { + constexpr std::size_t operator()(const klotski::cases::Group &g) const noexcept { + // TODO: perf hash alg + return std::hash{}(g.type_id() ^ g.pattern_id() ^ (int)g.toward()); + } +}; + +template <> +struct std::hash { + constexpr std::size_t operator()(const klotski::cases::GroupUnion &gu) const noexcept { + return std::hash{}(gu.unwrap()); + } +}; + +} // namespace std + +// ----------------------------------------------------------------------------------------- // diff --git a/src/core/group/internal/group_cases_pro.cc b/src/core/group/internal/group_cases_pro.cc index d3df3c2..2061726 100644 --- a/src/core/group/internal/group_cases_pro.cc +++ b/src/core/group/internal/group_cases_pro.cc @@ -37,7 +37,7 @@ std::vector> build_ranges_unions() { auto group_union = GroupUnion::unsafe_create(type_id); for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { std::vector groups; - for (auto group : group_union.groups_pro()) { + for (auto group : group_union.groups()) { if (group.pattern_id() == pattern_id) { groups.emplace_back(group); } @@ -60,7 +60,7 @@ static std::vector build_tmp_data() { for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { auto group_union = GroupUnion::unsafe_create(type_id); - for (auto group : group_union.groups_pro()) { + for (auto group : group_union.groups()) { uint32_t pattern_id = group.pattern_id(); auto toward_id = (uint32_t)group.toward(); diff --git a/src/core/group/internal/group_union.inl b/src/core/group/internal/group_union.inl index 1a3ccb2..c457977 100644 --- a/src/core/group/internal/group_union.inl +++ b/src/core/group/internal/group_union.inl @@ -54,13 +54,12 @@ constexpr uint32_t GroupUnion::max_group_size() const { //} // ----------------------------------------------------------------------------------------- // -// TODO: new interface constexpr uint32_t GroupUnion::pattern_num() const { return PATTERN_NUM[type_id_]; } -inline std::vector GroupUnion::groups_pro() const { +constexpr std::vector GroupUnion::groups() const { std::vector groups; groups.reserve(group_num()); for (uint32_t pattern_id = 0; pattern_id < pattern_num(); ++pattern_id) { @@ -86,17 +85,43 @@ inline std::vector GroupUnion::groups_pro() const { return groups; } +constexpr std::optional> GroupUnion::groups(uint32_t pattern_id) const { + if (pattern_id >= pattern_num()) { + return std::nullopt; + } + std::vector groups; + auto group = Group::unsafe_create(type_id_, pattern_id, Group::Toward::A); + groups.emplace_back(group); + switch (group.mirror_type()) { + case Group::MirrorType::Full: + break; + case Group::MirrorType::Horizontal: + groups.emplace_back(Group::unsafe_create(type_id_, pattern_id, Group::Toward::C)); + break; + case Group::MirrorType::Centro: + case Group::MirrorType::Vertical: + groups.emplace_back(Group::unsafe_create(type_id_, pattern_id, Group::Toward::B)); + break; + case Group::MirrorType::Ordinary: + groups.emplace_back(Group::unsafe_create(type_id_, pattern_id, Group::Toward::B)); + groups.emplace_back(Group::unsafe_create(type_id_, pattern_id, Group::Toward::C)); + groups.emplace_back(Group::unsafe_create(type_id_, pattern_id, Group::Toward::D)); + break; + } + return groups; +} + // ----------------------------------------------------------------------------------------- // -inline GroupUnion GroupUnion::from_raw_code(const codec::RawCode raw_code) { +constexpr GroupUnion GroupUnion::from_raw_code(const codec::RawCode raw_code) { return unsafe_create(type_id(raw_code)); } -inline GroupUnion GroupUnion::from_short_code(const codec::ShortCode short_code) { +constexpr GroupUnion GroupUnion::from_short_code(const codec::ShortCode short_code) { return from_common_code(short_code.to_common_code()); } -inline GroupUnion GroupUnion::from_common_code(const codec::CommonCode common_code) { +constexpr GroupUnion GroupUnion::from_common_code(const codec::CommonCode common_code) { return unsafe_create(type_id(common_code)); } diff --git a/src/core_test/CMakeLists.txt b/src/core_test/CMakeLists.txt index 98e2ee6..5a523f5 100644 --- a/src/core_test/CMakeLists.txt +++ b/src/core_test/CMakeLists.txt @@ -29,7 +29,7 @@ set(KLSK_TEST_CASES_SRC cases/basic_ranges.cc cases/all_cases.cc cases/group_union.cc - cases/group.cc + cases/group_legacy.cc cases/helper/group_impl.cc cases/group_pro.cc ) diff --git a/src/core_test/cases/group.cc b/src/core_test/cases/group_legacy.cc similarity index 79% rename from src/core_test/cases/group.cc rename to src/core_test/cases/group_legacy.cc index 38a4d3c..8975891 100644 --- a/src/core_test/cases/group.cc +++ b/src/core_test/cases/group_legacy.cc @@ -12,10 +12,10 @@ using klotski::cases::TYPE_ID_LIMIT; // TODO: hash check for every group -TEST(Group, cases) { - - for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - auto group_union = GroupUnion::unsafe_create(type_id); +//TEST(Group, cases) { +// +// for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { +// auto group_union = GroupUnion::unsafe_create(type_id); // for (auto group : group_union.groups()) { // @@ -27,9 +27,9 @@ TEST(Group, cases) { // EXPECT_EQ(codes.size(), group.size()); // // } - } - -} +// } +// +//} // TODO: test from_raw_code / from_short_code / from_common_code diff --git a/src/core_test/cases/group_pro.cc b/src/core_test/cases/group_pro.cc index b941fcb..358c42d 100644 --- a/src/core_test/cases/group_pro.cc +++ b/src/core_test/cases/group_pro.cc @@ -10,7 +10,7 @@ using klotski::cases::Group; using klotski::cases::GroupUnion; -TEST(GroupPro, demo) { +TEST(Group, demo) { std::cout << helper::group_union_num() << std::endl; @@ -24,19 +24,19 @@ TEST(GroupPro, demo) { std::cout << (int)Group::unsafe_create(169, 0, Group::Toward::A).mirror_type() << std::endl; std::cout << std::format("{}", helper::pattern_toward_list(169, 0)) << std::endl; - std::cout << (int)GroupUnion::unsafe_create(169).groups_pro()[0].toward() << std::endl; - std::cout << (int)GroupUnion::unsafe_create(169).groups_pro()[1].toward() << std::endl; + std::cout << (int)GroupUnion::unsafe_create(169).groups()[0].toward() << std::endl; + std::cout << (int)GroupUnion::unsafe_create(169).groups()[1].toward() << std::endl; - auto group_1 = GroupUnion::unsafe_create(169).groups_pro()[0]; + auto group_1 = GroupUnion::unsafe_create(169).groups()[0]; EXPECT_EQ(group_1.cases().codes(), helper::group_cases(169, 0, (uint32_t)group_1.toward())); - auto group_2 = GroupUnion::unsafe_create(169).groups_pro()[1]; + auto group_2 = GroupUnion::unsafe_create(169).groups()[1]; EXPECT_EQ(group_2.cases().codes(), helper::group_cases(169, 0, (uint32_t)group_2.toward())); } -TEST(GroupPro, cases) { +TEST(Group, cases) { GROUP_UNION_PARALLEL({ - for (auto group : group_union.groups_pro()) { + for (auto group : group_union.groups()) { const auto &cases = helper::group_cases(group.type_id(), group.pattern_id(), (uint32_t)group.toward()); EXPECT_EQ(group.size(), cases.size()); EXPECT_EQ(group.cases().codes(), cases); @@ -56,9 +56,9 @@ TEST(GroupPro, cases) { }); } -TEST(GroupPro, v_mirror) { +TEST(Group, v_mirror) { GROUP_UNION_PARALLEL({ - for (auto group : group_union.groups_pro()) { + for (auto group : group_union.groups()) { auto g = group.to_vertical_mirror(); EXPECT_EQ(group.type_id(), g.type_id()); EXPECT_EQ(group.pattern_id(), g.pattern_id()); @@ -96,9 +96,9 @@ TEST(GroupPro, v_mirror) { }); } -TEST(GroupPro, h_mirror) { +TEST(Group, h_mirror) { GROUP_UNION_PARALLEL({ - for (auto group : group_union.groups_pro()) { + for (auto group : group_union.groups()) { auto g = group.to_horizontal_mirror(); EXPECT_EQ(group.type_id(), g.type_id()); EXPECT_EQ(group.pattern_id(), g.pattern_id()); diff --git a/src/core_test/cases/group_union.cc b/src/core_test/cases/group_union.cc index 9caf81f..63a0979 100644 --- a/src/core_test/cases/group_union.cc +++ b/src/core_test/cases/group_union.cc @@ -41,7 +41,7 @@ TEST(GroupUnion, basic) { EXPECT_FALSE(GroupUnion::create(TYPE_ID_LIMIT).has_value()); GROUP_UNION_PARALLEL({ - const auto groups = group_union.groups_pro(); + const auto groups = group_union.groups(); EXPECT_EQ(groups.size(), group_union.group_num()); auto get_type_id = [](const auto g) { return g.type_id(); }; @@ -87,7 +87,7 @@ TEST(GroupUnion, values) { EXPECT_EQ(group_union.group_num(), group_num(type_id)); auto get_group_size = [](auto g) { return g.size(); }; - const auto sizes = group_union.groups_pro() | std::views::transform(get_group_size); + const auto sizes = group_union.groups() | std::views::transform(get_group_size); EXPECT_EQ(group_union.max_group_size(), *std::ranges::max_element(sizes)); }); } @@ -103,10 +103,10 @@ TEST(GroupUnion, values_pro) { EXPECT_EQ(group_union.pattern_num(), helper::group_union_pattern_num(type_id)); auto get_group_size = [](auto g) { return g.size(); }; - const auto sizes = group_union.groups_pro() | std::views::transform(get_group_size); + const auto sizes = group_union.groups() | std::views::transform(get_group_size); EXPECT_EQ(group_union.max_group_size(), *std::ranges::max_element(sizes)); - auto groups = group_union.groups_pro(); + auto groups = group_union.groups(); for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { std::vector towards; for (auto group : groups) {