From ddc16c3d45c7721f19c2b13e2b4a92c8a4cc962e Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sat, 1 Mar 2025 15:52:14 +0800 Subject: [PATCH] perf: pre-build of all group cases --- src/core/benchmark/group.cc | 80 ++++++------ src/core/group/internal/group_cases.cc | 171 +++++++++++++++++++++++-- 2 files changed, 201 insertions(+), 50 deletions(-) diff --git a/src/core/benchmark/group.cc b/src/core/benchmark/group.cc index 47fb2e2..bc541b7 100644 --- a/src/core/benchmark/group.cc +++ b/src/core/benchmark/group.cc @@ -134,24 +134,24 @@ static void GroupExtend(benchmark::State &state) { for (auto _ : state) { - // for (int type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - // for (auto group : GroupUnion::unsafe_create(type_id).groups()) { - // // if (group.mirror_type() == Group::MirrorType::Full) { - // // if (group.mirror_type() == Group::MirrorType::Horizontal) { - // // if (group.mirror_type() == Group::MirrorType::Centro) { - // // if (group.mirror_type() == Group::MirrorType::Vertical) { - // if (group.mirror_type() == Group::MirrorType::Ordinary) { - // // std::println("{} ({})", group.to_string(), group.size()); - // volatile auto kk = group.cases(); - // } - // } - // } + for (int type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + for (auto group : GroupUnion::unsafe_create(type_id).groups()) { + if (group.mirror_type() == Group::MirrorType::Full) { + // if (group.mirror_type() == Group::MirrorType::Horizontal) { + // if (group.mirror_type() == Group::MirrorType::Centro) { + // if (group.mirror_type() == Group::MirrorType::Vertical) { + // if (group.mirror_type() == Group::MirrorType::Ordinary) { + // std::println("{} ({})", group.to_string(), group.size()); + volatile auto kk = group.cases(); + } + } + } - constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A); + // constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A); // constexpr auto group = Group::unsafe_create(51, 0, Group::Toward::A); // constexpr auto group = Group::unsafe_create(98, 0, Group::Toward::A); - volatile auto kk = group.cases(); + // volatile auto kk = group.cases(); // for (auto group : groups) { // volatile auto tmp = group.cases(); @@ -370,7 +370,7 @@ static void ToHorizontalMirror(benchmark::State &state) { } static void FastObtainCode(benchmark::State &state) { - // GroupCases::build(); + GroupCases::build(); // std::vector infos; // for (auto code : common_code_samples(64)) { @@ -379,14 +379,14 @@ static void FastObtainCode(benchmark::State &state) { // infos.emplace_back(CaseInfo::unsafe_create()) const auto group = Group::unsafe_create(169, 0, Group::Toward::C); - const klotski::cases::RangesUnion data = group.cases(); - - std::array sizes {}; - size_t offset = 0; - for (int i = 0; i < 16; ++i) { - sizes[i] = offset; - offset += data.ranges(i).size(); - } + // const klotski::cases::RangesUnion data = group.cases(); + // + // std::array sizes {}; + // size_t offset = 0; + // for (int i = 0; i < 16; ++i) { + // sizes[i] = offset; + // offset += data.ranges(i).size(); + // } std::vector infos { CaseInfo::unsafe_create(group, 2631), @@ -415,24 +415,27 @@ static void FastObtainCode(benchmark::State &state) { for (auto info : infos) { /// about 35ns - auto &cases = data; - uint64_t head = 0; - auto case_id = info.case_id(); - for (;;) { - if (case_id >= cases.ranges(head).size()) { - case_id -= cases.ranges(head).size(); - ++head; - } else { - break; - } - } - auto range = cases.ranges(head)[case_id]; - volatile auto kk = CommonCode::unsafe_create(head << 32 | range); + // auto &cases = data; + // uint64_t head = 0; + // auto case_id = info.case_id(); + // for (;;) { + // if (case_id >= cases.ranges(head).size()) { + // case_id -= cases.ranges(head).size(); + // ++head; + // } else { + // break; + // } + // } + // auto range = cases.ranges(head)[case_id]; + // volatile auto kk = CommonCode::unsafe_create(head << 32 | range); /// about 117ns // uint64_t head = std::upper_bound(sizes.begin(), sizes.end(), info.case_id()) - sizes.begin() - 1; // uint32_t range = data[head][info.case_id() - sizes[head]]; // volatile auto kk = CommonCode::unsafe_create(head << 32 | range); + + volatile auto kk = GroupCases::fast_obtain_code(info); + } } @@ -441,7 +444,8 @@ static void FastObtainCode(benchmark::State &state) { // BENCHMARK(CommonCodeToTypeId)->Arg(8)->Arg(64)->Arg(256); // BENCHMARK(RawCodeToTypeId)->Arg(8)->Arg(64)->Arg(256); -BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond); +// BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond); +// BENCHMARK(GroupExtend)->Unit(benchmark::kMicrosecond); // BENCHMARK(FilterFromAllCases)->Unit(benchmark::kMillisecond); @@ -457,7 +461,7 @@ BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond); // BENCHMARK(GroupFromRawCode)->Unit(benchmark::kMillisecond); -// BENCHMARK(FastObtainCode); +BENCHMARK(FastObtainCode); // BENCHMARK(IsVerticalMirror); // BENCHMARK(IsHorizontalMirror); diff --git a/src/core/group/internal/group_cases.cc b/src/core/group/internal/group_cases.cc index 5488f08..fa3b3d3 100644 --- a/src/core/group/internal/group_cases.cc +++ b/src/core/group/internal/group_cases.cc @@ -26,6 +26,10 @@ using klotski::cases::ALL_CASES_NUM_; using klotski::group::GROUP_DATA; +using klotski::group::PATTERN_DATA; +using klotski::group::PATTERN_OFFSET; +using klotski::group::ALL_PATTERN_NUM; + struct case_info_t { uint32_t pattern_id : 10; uint32_t toward_id : 2; @@ -34,8 +38,24 @@ struct case_info_t { static_assert(sizeof(case_info_t) == 4); +// struct ru_array_t { +// RangesUnion arr[4]; +// }; + // TODO: benchmark of `phmap` structure -static std::vector> *ru_data = nullptr; // group_offset + toward +// static std::vector> *ru_data = nullptr; // group_offset + toward + +static std::vector *ru_data_flat = nullptr; + +static phmap::flat_hash_map *ru_data_map = nullptr; + +// static std::vector> *ru_data_array = nullptr; +// static std::vector *ru_data_array = nullptr; +// static std::array, ALL_PATTERN_NUM> *ru_data_array = nullptr; +static std::array *ru_data_array = nullptr; + +static std::vector> *ru_data_vector = nullptr; + static std::vector *rev_data = nullptr; std::vector> build_ranges_unions() { @@ -66,6 +86,108 @@ std::vector> build_ranges_unions() { return unions; } +static std::vector build_ru_flat() { + + // phmap::flat_hash_map global_data; + // global_data.reserve(ALL_GROUP_NUM); + + std::vector global_data; + global_data.resize(ALL_PATTERN_NUM * 4); + + for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + + const auto gu = GroupUnion::unsafe_create(type_id); + for (auto group : gu.groups()) { + // auto _ = group.cases(); + + // global_data.emplace(group, group.cases()); + + const auto index = (PATTERN_OFFSET[type_id] + group.pattern_id()) * 4 + (int)group.toward(); + global_data[index] = group.cases(); + } + } + + // std::println("load factor: {}", global_data.load_factor()); + + return global_data; +} + +static phmap::flat_hash_map build_ru_map() { + + phmap::flat_hash_map global_data; + global_data.reserve(ALL_GROUP_NUM); + + for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + + const auto gu = GroupUnion::unsafe_create(type_id); + for (auto group : gu.groups()) { + global_data.emplace(group, group.cases()); + } + } + + return global_data; +} + +// std::array, ALL_PATTERN_NUM> build_ru_array() { +std::array build_ru_array() { + + // std::vector> data; + // std::vector data; + // std::array, ALL_PATTERN_NUM> data; + std::array data; + + // data.reserve(ALL_PATTERN_NUM); + + for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + auto group_union = GroupUnion::unsafe_create(type_id); + for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { + std::vector groups; + for (auto group : group_union.groups()) { + if (group.pattern_id() == pattern_id) { + groups.emplace_back(group); + } + } + + auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; + for (auto group : groups) { + data[flat_id * 4 + (int)group.toward()] = group.cases(); + } + + // ru_array_t tmp {}; + // for (auto group : groups) { + // tmp.arr[(int)group.toward()] = group.cases(); + // } + // data.emplace_back(tmp); + } + } + return data; +} + +std::vector> build_ru_vector() { + std::vector> data; + data.reserve(ALL_GROUP_NUM); + + for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + auto group_union = GroupUnion::unsafe_create(type_id); + for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) { + std::vector groups; + for (auto group : group_union.groups()) { + if (group.pattern_id() == pattern_id) { + groups.emplace_back(group); + } + } + + std::vector tmp {}; + tmp.resize(4); + for (auto group : groups) { + tmp[(int)group.toward()] = group.cases(); + } + data.emplace_back(tmp); + } + } + return data; +} + static std::vector build_tmp_data() { std::vector data; data.resize(ALL_CASES_NUM_); @@ -93,19 +215,35 @@ static std::vector build_tmp_data() { } void GroupCases::build() { - if (fast_) { - return; - } - std::lock_guard guard {busy_}; + // if (fast_) { + // return; + // } + // std::lock_guard guard {busy_}; // TODO: make `data` as class member - static auto data_1 = build_ranges_unions(); - static auto data_2 = build_tmp_data(); - ru_data = &data_1; - rev_data = &data_2; + // static auto data_1 = build_ranges_unions(); + // static auto data_2 = build_tmp_data(); + // ru_data = &data_1; + // rev_data = &data_2; + + // about 35.7ns + // static auto data_flat = build_ru_flat(); + // ru_data_flat = &data_flat; + + // about 108ns + // static auto data_map = build_ru_map(); + // ru_data_map = &data_map; + + // about 34.1ns + static auto data_array = build_ru_array(); + ru_data_array = &data_array; + + // about 34.4ns + // static auto data_vector = build_ru_vector(); + // ru_data_vector = &data_vector; - KLSK_MEM_BARRIER; - fast_ = true; + // KLSK_MEM_BARRIER; + // fast_ = true; } void GroupCases::build_async(Executor &&executor, Notifier &&callback) { @@ -120,9 +258,18 @@ CommonCode GroupCases::fast_obtain_code(CaseInfo info) { auto flat_id = PATTERN_OFFSET[info.group().type_id()] + info.group().pattern_id(); - auto &cases = (*ru_data)[flat_id][(int)info.group().toward()]; + // auto &cases = (*ru_data)[flat_id][(int)info.group().toward()]; // TODO: make offset table for perf + // auto &cases = (*ru_data_flat)[flat_id * 4 + (int)info.group().toward()]; + + // auto &cases = (*ru_data_map)[info.group()]; + + // auto &cases = (*ru_data_array)[flat_id][(int)info.group().toward()]; + auto &cases = (*ru_data_array)[flat_id * 4 + (int)info.group().toward()]; + + // auto &cases = (*ru_data_vector)[flat_id][(int)info.group().toward()]; + uint64_t head = 0; auto case_id = info.case_id();