From 73578fb04c139e0322fda9b44d73d427ce14e647 Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sat, 1 Mar 2025 17:47:25 +0800 Subject: [PATCH] perf: batch build of group cases --- src/core/benchmark/group.cc | 12 +- src/core/group/internal/group_cases.cc | 304 +++++++++++++++--------- src/core/group/internal/group_union.inl | 2 +- src/core/main.cc | 31 ++- 4 files changed, 224 insertions(+), 125 deletions(-) diff --git a/src/core/benchmark/group.cc b/src/core/benchmark/group.cc index bc541b7..fea7949 100644 --- a/src/core/benchmark/group.cc +++ b/src/core/benchmark/group.cc @@ -441,6 +441,14 @@ static void FastObtainCode(benchmark::State &state) { } } +static void GroupCasesBuild(benchmark::State &state) { + + for (auto _ : state) { + GroupCases::build(); + } + +} + // BENCHMARK(CommonCodeToTypeId)->Arg(8)->Arg(64)->Arg(256); // BENCHMARK(RawCodeToTypeId)->Arg(8)->Arg(64)->Arg(256); @@ -461,7 +469,9 @@ static void FastObtainCode(benchmark::State &state) { // BENCHMARK(GroupFromRawCode)->Unit(benchmark::kMillisecond); -BENCHMARK(FastObtainCode); +// BENCHMARK(FastObtainCode); + +BENCHMARK(GroupCasesBuild)->Unit(benchmark::kMillisecond); // BENCHMARK(IsVerticalMirror); // BENCHMARK(IsHorizontalMirror); diff --git a/src/core/group/internal/group_cases.cc b/src/core/group/internal/group_cases.cc index fa3b3d3..7b8c96a 100644 --- a/src/core/group/internal/group_cases.cc +++ b/src/core/group/internal/group_cases.cc @@ -38,153 +38,247 @@ struct case_info_t { static_assert(sizeof(case_info_t) == 4); -// struct ru_array_t { -// RangesUnion arr[4]; -// }; +static std::array *ru_data_array = nullptr; -// TODO: benchmark of `phmap` structure -// static std::vector> *ru_data = nullptr; // group_offset + toward +static std::vector *rev_data = nullptr; -static std::vector *ru_data_flat = nullptr; +#define RELEASE_INTO(RU) \ + [&RU](const RawCode raw_code) { \ + const auto code = raw_code.to_common_code().unwrap(); \ + RU.ranges(code >> 32).emplace_back(static_cast(code)); \ + } -static phmap::flat_hash_map *ru_data_map = nullptr; +#define NO_MIRROR \ + [](RawCode, auto) {} -// static std::vector> *ru_data_array = nullptr; -// static std::vector *ru_data_array = nullptr; -// static std::array, ALL_PATTERN_NUM> *ru_data_array = nullptr; -static std::array *ru_data_array = nullptr; +template +KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF add_mirror, RF release) { + std::vector queue, mirrors; + phmap::flat_hash_map cases; -static std::vector> *ru_data_vector = nullptr; + queue.reserve(reserve); mirrors.reserve(reserve); + cases.reserve(static_cast(static_cast(reserve) * 1.56)); // reduce load factor -static std::vector *rev_data = nullptr; + auto mover = MaskMover([&queue, &cases, &mirrors, add_mirror](RawCode code, uint64_t hint) { + if (const auto [iter, ret] = cases.try_emplace(code, hint); !ret) { + iter->second |= hint; // update hint + return; + } + queue.emplace_back(code); + add_mirror(code, [&cases, &mirrors](RawCode mirror) { + cases.emplace(mirror, 0); // without hint + mirrors.emplace_back(mirror); + }); + }); -std::vector> build_ranges_unions() { - std::vector> unions; - unions.reserve(ALL_GROUP_NUM); + queue.emplace_back(seed); + cases.emplace(seed, 0); + add_mirror(seed, [&mirrors, &cases](RawCode mirror) { + cases.emplace(mirror, 0); // without hint + mirrors.emplace_back(mirror); + }); - // TODO: add white list for single-group unions + size_t offset = 0; + while (offset != queue.size()) { + const auto curr = queue[offset++]; + mover.next_cases(curr, cases.find(curr)->second); + } + for (const auto code : queue) { release(code); } + for (const auto code : mirrors) { release(code); } +} - // TODO: helper with mirror +KLSK_NOINLINE static void extend_full_pattern(RawCode seed, size_t size, RangesUnion &a) { + const auto mirror_func = [](const RawCode code, auto spawn) { + const auto m_vrt = code.to_vertical_mirror(); + spawn(m_vrt); + if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { + spawn(m_hor); + spawn(m_vrt.to_horizontal_mirror()); + } + }; - for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - auto group_union = GroupUnion::unsafe_create(type_id); - for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { - std::vector groups; - for (auto group : group_union.groups()) { - if (group.pattern_id() == pattern_id) { - groups.emplace_back(group); - } - } + extend(seed, size, mirror_func, [&a](const RawCode raw_code) { + const auto code = raw_code.to_common_code().unwrap(); + a.ranges(code >> 32).emplace_back(static_cast(code)); + }); - std::vector tmp {4}; - for (auto group : groups) { - tmp[(int)group.toward()] = group.cases(); - } - unions.emplace_back(tmp); - } + for (auto head : RangesUnion::Heads) { + std::stable_sort(a.ranges(head).begin(), a.ranges(head).end()); } - return unions; + } -static std::vector build_ru_flat() { +KLSK_NOINLINE static void extend_hor_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &c) { + const auto mirror_func = [](const RawCode code, auto spawn) { + if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { + spawn(m_hor); + } + }; + extend(seed, size, mirror_func, [&a, &c](const RawCode raw_code) { + const auto code = raw_code.to_common_code().unwrap(); + a.ranges(code >> 32).emplace_back(static_cast(code)); - // phmap::flat_hash_map global_data; - // global_data.reserve(ALL_GROUP_NUM); + const auto code_ = raw_code.to_vertical_mirror().to_common_code().unwrap(); + c.ranges(code_ >> 32).emplace_back(static_cast(code_)); + }); - std::vector global_data; - global_data.resize(ALL_PATTERN_NUM * 4); + for (auto head : RangesUnion::Heads) { + std::stable_sort(a.ranges(head).begin(), a.ranges(head).end()); + std::stable_sort(c.ranges(head).begin(), c.ranges(head).end()); + } - for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { +} - const auto gu = GroupUnion::unsafe_create(type_id); - for (auto group : gu.groups()) { - // auto _ = group.cases(); +KLSK_NOINLINE static void extend_cen_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b) { + const auto mirror_func = [](const RawCode code, auto spawn) { + spawn(code.to_diagonal_mirror()); + }; + extend(seed, size, mirror_func, [&a, &b](const RawCode raw_code) { + const auto code = raw_code.to_common_code().unwrap(); + a.ranges(code >> 32).emplace_back(static_cast(code)); - // global_data.emplace(group, group.cases()); + const auto code_ = raw_code.to_horizontal_mirror().to_common_code().unwrap(); + b.ranges(code_ >> 32).emplace_back(static_cast(code_)); + }); - const auto index = (PATTERN_OFFSET[type_id] + group.pattern_id()) * 4 + (int)group.toward(); - global_data[index] = group.cases(); - } + for (auto head : RangesUnion::Heads) { + std::stable_sort(a.ranges(head).begin(), a.ranges(head).end()); + std::stable_sort(b.ranges(head).begin(), b.ranges(head).end()); } +} - // std::println("load factor: {}", global_data.load_factor()); +KLSK_NOINLINE static void extend_ver_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b) { + const auto mirror_func = [](const RawCode code, auto spawn) { + spawn(code.to_vertical_mirror()); + }; + extend(seed, size, mirror_func, [&a, &b](const RawCode raw_code) { + const auto code = raw_code.to_common_code().unwrap(); + a.ranges(code >> 32).emplace_back(static_cast(code)); + + const auto code_ = raw_code.to_horizontal_mirror().to_common_code().unwrap(); + b.ranges(code_ >> 32).emplace_back(static_cast(code_)); + }); - return global_data; + for (auto head : RangesUnion::Heads) { + std::stable_sort(a.ranges(head).begin(), a.ranges(head).end()); + std::stable_sort(b.ranges(head).begin(), b.ranges(head).end()); + } } -static phmap::flat_hash_map build_ru_map() { +KLSK_NOINLINE static void extend_ord_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b, RangesUnion &c, RangesUnion &d) { + extend(seed, size, NO_MIRROR, [&a, &b, &c, &d](const RawCode raw_code) { + const auto code = raw_code.to_common_code().unwrap(); + a.ranges(code >> 32).emplace_back(static_cast(code)); - phmap::flat_hash_map global_data; - global_data.reserve(ALL_GROUP_NUM); + const auto code_1 = raw_code.to_horizontal_mirror().to_common_code().unwrap(); + b.ranges(code_1 >> 32).emplace_back(static_cast(code_1)); - for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + const auto code_2 = raw_code.to_vertical_mirror().to_common_code().unwrap(); + c.ranges(code_2 >> 32).emplace_back(static_cast(code_2)); - const auto gu = GroupUnion::unsafe_create(type_id); - for (auto group : gu.groups()) { - global_data.emplace(group, group.cases()); - } - } + const auto code_3 = raw_code.to_diagonal_mirror().to_common_code().unwrap(); + d.ranges(code_3 >> 32).emplace_back(static_cast(code_3)); + }); - return global_data; + for (auto head : RangesUnion::Heads) { + std::stable_sort(a.ranges(head).begin(), a.ranges(head).end()); + std::stable_sort(b.ranges(head).begin(), b.ranges(head).end()); + std::stable_sort(c.ranges(head).begin(), c.ranges(head).end()); + std::stable_sort(d.ranges(head).begin(), d.ranges(head).end()); + } } -// std::array, ALL_PATTERN_NUM> build_ru_array() { std::array build_ru_array() { - - // std::vector> data; - // std::vector data; - // std::array, ALL_PATTERN_NUM> data; std::array data; - // data.reserve(ALL_PATTERN_NUM); - for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { auto group_union = GroupUnion::unsafe_create(type_id); - for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { - std::vector groups; - for (auto group : group_union.groups()) { - if (group.pattern_id() == pattern_id) { - groups.emplace_back(group); - } - } - auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; - for (auto group : groups) { - data[flat_id * 4 + (int)group.toward()] = group.cases(); - } + if (group_union.group_num() == 1) { // only single group + + // std::println("type_id = {}", type_id); + + data[PATTERN_OFFSET[type_id] * 4] = group_union.cases(); + continue; + } - // ru_array_t tmp {}; + uint32_t pattern_id_begin = 0; + if ((PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0) { // first pattern is `x` + pattern_id_begin = 1; + } + + for (uint32_t pattern_id = pattern_id_begin; pattern_id < group_union.pattern_num(); ++pattern_id) { + // std::vector groups; + // for (auto group : group_union.groups()) { + // if (group.pattern_id() == pattern_id) { + // groups.emplace_back(group); + // } + // } + // + // auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; // for (auto group : groups) { - // tmp.arr[(int)group.toward()] = group.cases(); + // data[flat_id * 4 + (int)group.toward()] = Group_cases(group); // } - // data.emplace_back(tmp); + + const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; + const auto mirror_type = static_cast(PATTERN_DATA[flat_id] & 0b111); + + const auto seed_val = PATTERN_DATA[flat_id] >> 23; + auto seed = CommonCode::unsafe_create(seed_val).to_raw_code(); + const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF; + + if (mirror_type == Group::MirrorType::Full) { + extend_full_pattern(seed, size, data[flat_id * 4]); + } else if (mirror_type == Group::MirrorType::Horizontal) { + extend_hor_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 2]); + } else if (mirror_type == Group::MirrorType::Centro) { + extend_cen_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1]); + } else if (mirror_type == Group::MirrorType::Vertical) { + extend_ver_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1]); + } else if (mirror_type == Group::MirrorType::Ordinary) { + extend_ord_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1], data[flat_id * 4 + 2], data[flat_id * 4 + 3]); + } + } - } - return data; -} -std::vector> build_ru_vector() { - std::vector> data; - data.reserve(ALL_GROUP_NUM); + if (pattern_id_begin == 1) { // first pattern is `x` - for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - auto group_union = GroupUnion::unsafe_create(type_id); - for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { - std::vector groups; - for (auto group : group_union.groups()) { - if (group.pattern_id() == pattern_id) { - groups.emplace_back(group); - } + RangesUnion others; + + size_t index_begin = (PATTERN_OFFSET[type_id] + 1) * 4; + size_t index_end = (PATTERN_OFFSET[type_id] + group_union.pattern_num()) * 4; + + for (size_t index = index_begin; index < index_end; ++index) { + others += data[index]; } - std::vector tmp {}; - tmp.resize(4); - for (auto group : groups) { - tmp[(int)group.toward()] = group.cases(); + for (auto head : RangesUnion::Heads) { + std::stable_sort(others.ranges(head).begin(), others.ranges(head).end()); } - data.emplace_back(tmp); + + data[PATTERN_OFFSET[type_id] * 4] = group_union.cases_without(others); + } + } + + + // verify + // for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + // const auto group_union = GroupUnion::unsafe_create(type_id); + // for (const auto group : group_union.groups()) { + // const auto flat_id = PATTERN_OFFSET[type_id] + group.pattern_id(); + // const auto index = flat_id * 4 + (int)group.toward(); + // + // const auto a = data[index].codes(); + // const auto b = group.cases().codes(); + // if (a != b) { + // std::cout << "!!! error: " << group << " | " << a.size() << " vs " << b.size() << std::endl; + // } + // } + // } + + return data; } @@ -226,22 +320,10 @@ void GroupCases::build() { // ru_data = &data_1; // rev_data = &data_2; - // about 35.7ns - // static auto data_flat = build_ru_flat(); - // ru_data_flat = &data_flat; - - // about 108ns - // static auto data_map = build_ru_map(); - // ru_data_map = &data_map; - // about 34.1ns static auto data_array = build_ru_array(); ru_data_array = &data_array; - // about 34.4ns - // static auto data_vector = build_ru_vector(); - // ru_data_vector = &data_vector; - // KLSK_MEM_BARRIER; // fast_ = true; } diff --git a/src/core/group/internal/group_union.inl b/src/core/group/internal/group_union.inl index a922b9c..497b7a4 100644 --- a/src/core/group/internal/group_union.inl +++ b/src/core/group/internal/group_union.inl @@ -18,7 +18,7 @@ constexpr uint32_t GroupUnion::group_num() const { return GROUP_NUM[type_id_]; } -constexpr uint32_t GroupUnion::pattern_num() const { +constexpr uint32_t GroupUnion::pattern_num() const { // TODO: why not using `uint16_t` return PATTERN_NUM[type_id_]; } diff --git a/src/core/main.cc b/src/core/main.cc index 87143d4..caf7f86 100644 --- a/src/core/main.cc +++ b/src/core/main.cc @@ -31,6 +31,7 @@ using klotski::codec::CommonCode; using klotski::group::GroupUnion; using klotski::group::Group; +using klotski::group::CaseInfo; using klotski::group::GroupCases; using klotski::group::GroupUnion; @@ -129,18 +130,24 @@ int main() { // perf-b -> ~1311ms // perf-c -> ~1272ms - for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - for (auto group: GroupUnion::unsafe_create(type_id).groups()) { - if (group.mirror_type() == Group::MirrorType::Full) { - // if (group.mirror_type() == Group::MirrorType::Horizontal) { - // if (group.mirror_type() == Group::MirrorType::Centro) { - // if (group.mirror_type() == Group::MirrorType::Vertical) { - // if (group.mirror_type() == Group::MirrorType::Ordinary) { - // std::println("{} ({})", group.to_string(), group.size()); - volatile auto kk = group.cases(); - } - } - } + // for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + // for (auto group: GroupUnion::unsafe_create(type_id).groups()) { + // // if (group.mirror_type() == Group::MirrorType::Full) { + // // if (group.mirror_type() == Group::MirrorType::Horizontal) { + // // if (group.mirror_type() == Group::MirrorType::Centro) { + // // if (group.mirror_type() == Group::MirrorType::Vertical) { + // // if (group.mirror_type() == Group::MirrorType::Ordinary) { + // // std::println("{} ({})", group.to_string(), group.size()); + // volatile auto kk = group.cases(); + // // } + // } + // } + + GroupCases::build(); + + constexpr auto group = Group::unsafe_create(169, 0, Group::Toward::C); + constexpr auto info = CaseInfo::unsafe_create(group, 7472); + std::cout << info << ": " << GroupCases::obtain_code(info) << std::endl; // constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A); // std::cout << group.to_string() << std::endl;