Browse Source

feat: group patterns test helper

legacy
Dnomd343 2 months ago
parent
commit
951d9c07bd
  1. 140
      src/core/main.cc
  2. 2
      src/core_test/CMakeLists.txt
  3. 31
      src/core_test/cases/group_pro.cc
  4. 25
      src/core_test/helper/group.h
  5. 288
      src/core_test/helper/internal/group.cc

140
src/core/main.cc

@ -34,124 +34,6 @@ using klotski::cases::GroupUnion;
using klotski::cases::TYPE_ID_LIMIT; using klotski::cases::TYPE_ID_LIMIT;
using klotski::codec::SHORT_CODE_LIMIT; using klotski::codec::SHORT_CODE_LIMIT;
struct Pattern {
enum class Mirror {
Full,
V,
HV,
H,
Common,
};
Pattern(CommonCode s) : seed(s) {}
Mirror mirror;
uint32_t size;
CommonCode seed;
uint32_t group_size;
std::array<std::vector<CommonCode>, 4> cases {};
};
std::vector<CommonCode> split_group(std::unordered_set<uint64_t> &codes, CommonCode origin) {
std::vector<CommonCode> group;
for (auto raw_code : Group::extend(origin.to_raw_code())) {
auto common_code = raw_code.to_common_code();
codes.erase(common_code.unwrap());
group.emplace_back(common_code);
}
std::sort(group.begin(), group.end());
return group;
}
std::vector<Pattern> split_patterns(const std::vector<CommonCode> &common_codes) {
std::unordered_set<uint64_t> codes;
for (auto code : common_codes) {
codes.emplace(code.unwrap());
}
std::vector<Pattern> patterns;
while (true) {
if (codes.empty()) {
break;
}
auto code_a = CommonCode::unsafe_create(*std::min_element(codes.begin(), codes.end()));
auto code_b = code_a.to_horizontal_mirror();
auto code_c = code_a.to_vertical_mirror();
auto code_d = code_b.to_vertical_mirror();
auto group_a = split_group(codes, code_a);
auto group_b = split_group(codes, code_b);
auto group_c = split_group(codes, code_c);
auto group_d = split_group(codes, code_d);
Pattern pattern {code_a};
pattern.cases[0] = group_a;
pattern.size = group_a.size();
pattern.group_size = group_a.size();
if (group_a.size() != group_b.size() || group_a.size() != group_c.size() || group_a.size() != group_d.size()) {
std::cout << "group size not match" << std::endl;
break;
}
if (group_a == group_b && group_a == group_c && group_a == group_d) {
pattern.mirror = Pattern::Mirror::Full;
} else if (group_a != group_b && group_a != group_c && group_a != group_d && \
group_b != group_c && group_b != group_d && group_c != group_d) {
pattern.mirror = Pattern::Mirror::Common;
pattern.cases[1] = group_b;
pattern.cases[2] = group_c;
pattern.cases[3] = group_d;
pattern.size *= 4;
} else if (group_a == group_b && group_c == group_d && group_a != group_c && group_b != group_d) {
pattern.mirror = Pattern::Mirror::V;
pattern.cases[1] = group_c;
pattern.size *= 2;
} else if (group_a == group_c && group_b == group_d && group_a != group_b && group_c != group_d) {
pattern.mirror = Pattern::Mirror::H;
pattern.cases[1] = group_b;
pattern.size *= 2;
} else if (group_a == group_d && group_b == group_c && group_a != group_b && group_c != group_d) {
pattern.mirror = Pattern::Mirror::HV;
pattern.cases[1] = group_b;
pattern.size *= 2;
} else {
std::cout << "unknown pattern" << std::endl;
break;
}
patterns.emplace_back(pattern);
}
std::sort(patterns.begin(), patterns.end(), [](const Pattern &lhs, const Pattern &rhs) {
if (lhs.size > rhs.size) {
return true;
}
if (lhs.size < rhs.size) {
return false;
}
if ((int)lhs.mirror < (int)rhs.mirror) {
return true;
}
if ((int)lhs.mirror > (int)rhs.mirror) {
return false;
}
if (lhs.seed < rhs.seed) {
return true;
}
if (lhs.seed > rhs.seed) {
return false;
}
std::cout << "pattern compare error" << std::endl;
return false;
});
return patterns;
}
int main() { int main() {
// const auto start = clock(); // const auto start = clock();
@ -171,28 +53,6 @@ int main() {
// std::cout << GroupCases::fast_obtain(CommonCode::unsafe_create(0x1A9BF0C00)) << std::endl; // std::cout << GroupCases::fast_obtain(CommonCode::unsafe_create(0x1A9BF0C00)) << std::endl;
// std::cout << GroupCases::fast_obtain(CommonCode::unsafe_create(0x4FEA13400)) << std::endl; // std::cout << GroupCases::fast_obtain(CommonCode::unsafe_create(0x4FEA13400)) << std::endl;
uint32_t flat_id = 0;
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto gu = GroupUnion::unsafe_create(type_id);
auto patterns = split_patterns(gu.cases().codes());
if (gu.pattern_num() != patterns.size()) {
std::cout << "pattern number error" << std::endl;
}
for (auto &pattern : patterns) {
if (pattern.group_size != klotski::cases::GROUP_PRO_SIZE[flat_id]) {
std::cout << "pattern group size error" << std::endl;
}
if ((int)pattern.mirror != klotski::cases::GROUP_PRO_TYPE[flat_id]) {
std::cout << "pattern mirror type error" << std::endl;
}
if (pattern.seed != klotski::cases::GROUP_PRO_SEED[flat_id]) {
std::cout << "pattern seed error" << std::endl;
}
++flat_id;
}
}
// std::cout << gu.pattern_num() << std::endl; // std::cout << gu.pattern_num() << std::endl;
// std::cout << gu.group_num() << std::endl; // std::cout << gu.group_num() << std::endl;
// std::cout << gu.groups_pro().size() << std::endl; // std::cout << gu.groups_pro().size() << std::endl;

2
src/core_test/CMakeLists.txt

@ -17,6 +17,7 @@ add_library(test_helper
helper/internal/concurrent.cc helper/internal/concurrent.cc
helper/internal/parallel.cc helper/internal/parallel.cc
helper/internal/hash.cc helper/internal/hash.cc
helper/internal/group.cc
) )
target_link_libraries(test_helper PRIVATE klotski_core bs::thread_pool md5sum::md5 xxHash::xxh3) target_link_libraries(test_helper PRIVATE klotski_core bs::thread_pool md5sum::md5 xxHash::xxh3)
@ -30,6 +31,7 @@ set(KLSK_TEST_CASES_SRC
cases/group_union.cc cases/group_union.cc
cases/group.cc cases/group.cc
cases/helper/group_impl.cc cases/helper/group_impl.cc
cases/group_pro.cc
) )
add_executable(test_klotski_cases ${KLSK_TEST_CASES_SRC}) add_executable(test_klotski_cases ${KLSK_TEST_CASES_SRC})

31
src/core_test/cases/group_pro.cc

@ -0,0 +1,31 @@
#include <gtest/gtest.h>
#include "helper/group.h"
#include "group/group.h"
using klotski::cases::GroupPro;
using klotski::cases::GroupUnion;
TEST(GroupPro, demo) {
std::cout << helper::group_union_num() << std::endl;
std::cout << helper::group_union_pattern_num(169) << std::endl;
std::cout << GroupUnion::unsafe_create(169).pattern_num() << std::endl;
std::cout << helper::group_union_group_num(169) << std::endl;
std::cout << GroupUnion::unsafe_create(169).group_num() << std::endl;
std::cout << (int)helper::pattern_mirror_type(169, 0) << std::endl;
std::cout << (int)GroupPro::unsafe_create(169, 0, 0).mirror_type() << std::endl;
std::cout << std::format("{}", helper::pattern_toward_list(169, 0)) << std::endl;
std::cout << (int)GroupUnion::unsafe_create(169).groups_pro()[0].mirror_toward() << std::endl;
std::cout << (int)GroupUnion::unsafe_create(169).groups_pro()[1].mirror_toward() << std::endl;
auto group_1 = GroupUnion::unsafe_create(169).groups_pro()[0];
EXPECT_EQ(group_1.cases().codes(), helper::group_cases(169, 0, group_1.mirror_toward()));
auto group_2 = GroupUnion::unsafe_create(169).groups_pro()[1];
EXPECT_EQ(group_2.cases().codes(), helper::group_cases(169, 0, group_2.mirror_toward()));
}

25
src/core_test/helper/group.h

@ -0,0 +1,25 @@
#pragma once
#include "common_code/common_code.h"
namespace helper {
using klotski::codec::CommonCode;
/// Get the type_id upper limit.
uint32_t group_union_num(); // TODO: remove it
/// Get cases contained in the specified type_id.
const std::vector<CommonCode>& group_union_cases(uint32_t type_id);
uint32_t group_union_pattern_num(uint32_t type_id);
uint32_t group_union_group_num(uint32_t type_id);
uint8_t pattern_mirror_type(uint32_t type_id, uint32_t pattern_id);
std::vector<uint8_t> pattern_toward_list(uint32_t type_id, uint32_t pattern_id);
const std::vector<CommonCode>& group_cases(uint32_t type_id, uint32_t pattern_id, uint8_t toward);
} // namespace helper

288
src/core_test/helper/internal/group.cc

@ -0,0 +1,288 @@
#include "helper/group.h"
#include "helper/block_num.h"
#include "group/group.h"
#include "all_cases/all_cases.h"
#include <iostream>
#include <algorithm>
#include <unordered_set>
using klotski::cases::AllCases;
using klotski::codec::CommonCode;
#define STATIC_DATA(name, impl) \
static const auto& name() { \
static auto data = [] {impl}(); \
return data; \
}
/// Filter cases with different type_id from AllCases.
STATIC_DATA(group_union_data, {
std::vector<std::vector<CommonCode>> codes;
codes.resize(helper::block_nums().size());
for (const auto code: AllCases::instance().fetch().codes()) {
const auto block_num = helper::cal_block_num(code.unwrap());
codes[to_type_id(block_num)].emplace_back(code);
}
return codes;
})
uint32_t helper::group_union_num() {
return group_union_data().size();
}
const std::vector<CommonCode>& helper::group_union_cases(const uint32_t type_id) {
if (type_id >= group_union_data().size()) {
std::abort();
}
return group_union_data()[type_id];
}
/// Extend ordered Group from the specified CommonCode seed.
static std::vector<CommonCode> extend_cases(CommonCode seed) {
// TODO: using inner build process -> only allow calling klotski::core
auto raw_codes = klotski::cases::Group::extend(seed.to_raw_code());
std::vector<CommonCode> common_codes {raw_codes.begin(), raw_codes.end()};
std::ranges::sort(common_codes.begin(), common_codes.end());
return common_codes;
}
struct Pattern {
enum class Mirror {
Full = 0,
V = 1,
HV = 2,
H = 3,
Common = 4,
};
Pattern(CommonCode s) : seed(s) {}
Mirror mirror;
uint32_t size;
CommonCode seed;
uint32_t group_size;
std::array<std::vector<CommonCode>, 4> cases {};
};
static std::vector<CommonCode> split_group(std::unordered_set<uint64_t> &codes, CommonCode origin) {
std::vector<CommonCode> group;
for (auto code : extend_cases(origin)) {
codes.erase(code.unwrap());
group.emplace_back(code);
}
std::sort(group.begin(), group.end());
return group;
}
static std::vector<Pattern> split_patterns(const std::vector<CommonCode> &common_codes) {
std::unordered_set<uint64_t> codes;
for (auto code : common_codes) {
codes.emplace(code.unwrap());
}
std::vector<Pattern> patterns;
while (true) {
if (codes.empty()) {
break;
}
auto code_a = CommonCode::unsafe_create(*std::min_element(codes.begin(), codes.end()));
auto code_b = code_a.to_horizontal_mirror();
auto code_c = code_a.to_vertical_mirror();
auto code_d = code_b.to_vertical_mirror();
auto group_a = split_group(codes, code_a);
auto group_b = split_group(codes, code_b);
auto group_c = split_group(codes, code_c);
auto group_d = split_group(codes, code_d);
Pattern pattern {code_a};
pattern.cases[0] = group_a;
pattern.size = group_a.size();
pattern.group_size = group_a.size();
if (group_a.size() != group_b.size() || group_a.size() != group_c.size() || group_a.size() != group_d.size()) {
std::cout << "group size not match" << std::endl;
break;
}
if (group_a == group_b && group_a == group_c && group_a == group_d) {
pattern.mirror = Pattern::Mirror::Full;
} else if (group_a != group_b && group_a != group_c && group_a != group_d && \
group_b != group_c && group_b != group_d && group_c != group_d) {
pattern.mirror = Pattern::Mirror::Common;
pattern.cases[1] = group_b;
pattern.cases[2] = group_c;
pattern.cases[3] = group_d;
pattern.size *= 4;
} else if (group_a == group_b && group_c == group_d && group_a != group_c && group_b != group_d) {
pattern.mirror = Pattern::Mirror::V;
pattern.cases[1] = group_c;
pattern.size *= 2;
} else if (group_a == group_c && group_b == group_d && group_a != group_b && group_c != group_d) {
pattern.mirror = Pattern::Mirror::H;
pattern.cases[1] = group_b;
pattern.size *= 2;
} else if (group_a == group_d && group_b == group_c && group_a != group_b && group_c != group_d) {
pattern.mirror = Pattern::Mirror::HV;
pattern.cases[1] = group_b;
pattern.size *= 2;
} else {
std::cout << "unknown pattern" << std::endl;
break;
}
patterns.emplace_back(pattern);
}
std::sort(patterns.begin(), patterns.end(), [](const Pattern &lhs, const Pattern &rhs) {
if (lhs.size > rhs.size) {
return true;
}
if (lhs.size < rhs.size) {
return false;
}
if ((int)lhs.mirror < (int)rhs.mirror) {
return true;
}
if ((int)lhs.mirror > (int)rhs.mirror) {
return false;
}
if (lhs.seed < rhs.seed) {
return true;
}
if (lhs.seed > rhs.seed) {
return false;
}
std::cout << "pattern compare error" << std::endl;
return false;
});
return patterns;
}
STATIC_DATA(pattern_data, {
std::vector<std::vector<Pattern>> patterns;
for (const auto &group_union : group_union_data()) {
patterns.emplace_back(split_patterns(group_union));
}
return patterns;
})
uint32_t helper::group_union_pattern_num(uint32_t type_id) {
if (type_id >= group_union_data().size()) {
std::abort();
}
return pattern_data()[type_id].size();
}
uint32_t helper::group_union_group_num(uint32_t type_id) {
if (type_id >= group_union_data().size()) {
std::abort();
}
uint32_t group_num {0};
for (const auto &pattern : pattern_data()[type_id]) {
switch (pattern.mirror) {
case Pattern::Mirror::Full:
++group_num;
break;
case Pattern::Mirror::V:
case Pattern::Mirror::HV:
case Pattern::Mirror::H:
group_num += 2;
break;
case Pattern::Mirror::Common:
group_num += 4;
break;
}
}
return group_num;
}
uint8_t helper::pattern_mirror_type(uint32_t type_id, uint32_t pattern_id) {
if (type_id >= group_union_data().size() || pattern_id >= pattern_data()[type_id].size()) {
std::abort();
}
const auto &pattern = pattern_data()[type_id][pattern_id];
return (uint8_t)pattern.mirror;
}
std::vector<uint8_t> helper::pattern_toward_list(uint32_t type_id, uint32_t pattern_id) {
if (type_id >= group_union_data().size() || pattern_id >= pattern_data()[type_id].size()) {
std::abort();
}
const auto &pattern = pattern_data()[type_id][pattern_id];
switch (pattern.mirror) {
case Pattern::Mirror::Full:
return {0};
case Pattern::Mirror::V:
return {0, 2};
case Pattern::Mirror::HV:
case Pattern::Mirror::H:
return {0, 1};
case Pattern::Mirror::Common:
return {0, 1, 2, 3};
}
}
const std::vector<CommonCode> &helper::group_cases(uint32_t type_id, uint32_t pattern_id, uint8_t toward) {
if (type_id >= group_union_data().size() || pattern_id >= pattern_data()[type_id].size()) {
std::abort();
}
const auto &pattern = pattern_data()[type_id][pattern_id];
if (toward == 0) {
return pattern.cases[0];
} else if (toward == 1) {
switch (pattern.mirror) {
case Pattern::Mirror::Full:
case Pattern::Mirror::V:
std::abort();
case Pattern::Mirror::HV:
case Pattern::Mirror::H:
case Pattern::Mirror::Common:
return pattern.cases[1];
}
} else if (toward == 2) {
switch (pattern.mirror) {
case Pattern::Mirror::Full:
case Pattern::Mirror::HV:
case Pattern::Mirror::H:
std::abort();
case Pattern::Mirror::V:
return pattern.cases[1];
case Pattern::Mirror::Common:
return pattern.cases[2];
}
} else if (toward == 3) {
switch (pattern.mirror) {
case Pattern::Mirror::Full:
case Pattern::Mirror::V:
case Pattern::Mirror::HV:
case Pattern::Mirror::H:
std::abort();
case Pattern::Mirror::Common:
return pattern.cases[3];
}
} else {
std::abort();
}
}
Loading…
Cancel
Save