From 9448dbe0c07d8d550becf7bf3b56e9ce076d6ee6 Mon Sep 17 00:00:00 2001 From: Dnomd343 Date: Sun, 2 Mar 2025 18:39:19 +0800 Subject: [PATCH] perf: construction of cases info list --- src/core/benchmark/group.cc | 2 + src/core/group/internal/group_cases.cc | 236 +++++++++++++------------ 2 files changed, 124 insertions(+), 114 deletions(-) diff --git a/src/core/benchmark/group.cc b/src/core/benchmark/group.cc index fea7949..92c691a 100644 --- a/src/core/benchmark/group.cc +++ b/src/core/benchmark/group.cc @@ -443,6 +443,8 @@ static void FastObtainCode(benchmark::State &state) { static void GroupCasesBuild(benchmark::State &state) { + klotski::codec::ShortCode::speed_up(true); + for (auto _ : state) { GroupCases::build(); } diff --git a/src/core/group/internal/group_cases.cc b/src/core/group/internal/group_cases.cc index cc66817..be101ea 100644 --- a/src/core/group/internal/group_cases.cc +++ b/src/core/group/internal/group_cases.cc @@ -41,6 +41,27 @@ using klotski::group::GROUP_UNION_CASES_NUM; #define RANGE_DERIVE_WITHOUT(HEAD) ranges.derive_without(HEAD, cases.ranges(HEAD), data.ranges(HEAD)) +static RangesUnion cases_without(uint8_t type_id, const RangesUnion &data) { + Ranges ranges {}; + ranges.reserve(BASIC_RANGES_NUM[type_id]); + const auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id]; + ranges.spawn(n, n_2x1, n_1x1); + ranges.reverse(); + + RangesUnion cases; + const auto [na, nb, nc, nd] = GROUP_UNION_CASES_NUM[type_id]; + RANGE_RESERVE(0x0, na); RANGE_RESERVE(0x1, nb); RANGE_RESERVE(0x2, na); + RANGE_RESERVE(0x4, nc); RANGE_RESERVE(0x5, nd); RANGE_RESERVE(0x6, nc); + RANGE_RESERVE(0x8, nc); RANGE_RESERVE(0x9, nd); RANGE_RESERVE(0xA, nc); + RANGE_RESERVE(0xC, na); RANGE_RESERVE(0xD, nb); RANGE_RESERVE(0xE, na); + + RANGE_DERIVE_WITHOUT(0x0); RANGE_DERIVE_WITHOUT(0x1); RANGE_DERIVE_WITHOUT(0x2); + RANGE_DERIVE_WITHOUT(0x4); RANGE_DERIVE_WITHOUT(0x5); RANGE_DERIVE_WITHOUT(0x6); + RANGE_DERIVE_WITHOUT(0x8); RANGE_DERIVE_WITHOUT(0x9); RANGE_DERIVE_WITHOUT(0xA); + RANGE_DERIVE_WITHOUT(0xC); RANGE_DERIVE_WITHOUT(0xD); RANGE_DERIVE_WITHOUT(0xE); + return cases; +} + using MirrorType = Group::MirrorType; struct case_info_t { @@ -51,6 +72,10 @@ struct case_info_t { static_assert(sizeof(case_info_t) == 4); +bool operator==(const case_info_t &left, const case_info_t &right) { + return std::bit_cast(left) == std::bit_cast(right); +} + static std::array *ru_data_array = nullptr; static std::vector *rev_data = nullptr; @@ -62,32 +87,33 @@ static std::vector *rev_data = nullptr; #define EMPLACE_INTO(RU, EXPR) \ EMPLACE_INTO_IMPL(RU, KLSK_UNIQUE(tmp), EXPR) -#define NO_MIRROR [](RawCode, auto) {} +#define CTR_M_IMPL spawn(code.to_diagonal_mirror()); -#define CTR_MIRROR \ - [](const RawCode code, auto spawn) { \ - spawn(code.to_diagonal_mirror()); \ - } +#define VRT_M_IMPL spawn(code.to_vertical_mirror()); -#define HOR_MIRROR \ - [](const RawCode code, auto spawn) { \ - if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { \ - spawn(m_hor); \ - } \ +#define HOR_M_IMPL \ + const auto m_hor = code.to_horizontal_mirror(); \ + if (m_hor != code) { \ + spawn(m_hor); \ } -#define VRT_MIRROR \ - [](const RawCode code, auto spawn) { \ - spawn(code.to_vertical_mirror()); \ +#define FULL_M_IMPL \ + const auto m_vrt = code.to_vertical_mirror(); \ + spawn(m_vrt); \ + const auto m_hor = code.to_horizontal_mirror(); \ + if (m_hor != code) { \ + spawn(m_hor); \ + spawn(m_vrt.to_horizontal_mirror()); \ } -#define FULL_MIRROR \ - [](const RawCode code, auto spawn) { \ - const auto m_vrt = code.to_vertical_mirror(); spawn(m_vrt); \ - if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { \ - spawn(m_hor); spawn(m_vrt.to_horizontal_mirror()); \ - } \ - } +#define M_FUNC(EXPR) \ + [](const RawCode code, auto spawn) { EXPR } + +#define NO_MIRROR [](RawCode, auto) {} +#define CTR_MIRROR M_FUNC(CTR_M_IMPL) +#define HOR_MIRROR M_FUNC(HOR_M_IMPL) +#define VRT_MIRROR M_FUNC(VRT_M_IMPL) +#define FULL_MIRROR M_FUNC(FULL_M_IMPL) template KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF add_mirror, RF release) { @@ -125,14 +151,6 @@ KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF ad for (const auto code : mirrors) { release(code); } } -// #define RELEASE_TO(A, B, C, D) \ -// [&data_a, &data_b, &data_c, &data_d](const RawCode code) { \ -// if constexpr(A == 'A') { EMPLACE_INTO(data_a, code); } \ -// if constexpr(B == 'B') { EMPLACE_INTO(data_b, code.to_horizontal_mirror()); } \ -// if constexpr(C == 'C') { EMPLACE_INTO(data_c, code.to_vertical_mirror()); } \ -// if constexpr(D == 'D') { EMPLACE_INTO(data_d, code.to_diagonal_mirror()); } \ -// } - template static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { RangesUnion &ga = data[0]; @@ -144,7 +162,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { extend(seed, size, FULL_MIRROR, [&ga](const RawCode code) { EMPLACE_INTO(ga, code); }); - // extend(seed, size, HOR_MIRROR, RELEASE_TO('A', 0, 0, 0)); } if constexpr(TYPE == MirrorType::Horizontal) { @@ -152,7 +169,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { EMPLACE_INTO(ga, code); EMPLACE_INTO(gc, code.to_vertical_mirror()); }); - // extend(seed, size, HOR_MIRROR, RELEASE_TO('A', 0, 'C', 0)); } if constexpr(TYPE == MirrorType::Centro) { @@ -160,7 +176,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { EMPLACE_INTO(ga, code); EMPLACE_INTO(gb, code.to_horizontal_mirror()); }); - // extend(seed, size, CTR_MIRROR, RELEASE_TO('A', 'B', 0, 0)); } if constexpr(TYPE == MirrorType::Vertical) { @@ -168,7 +183,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { EMPLACE_INTO(ga, code); EMPLACE_INTO(gb, code.to_horizontal_mirror()); }); - // extend(seed, size, VRT_MIRROR, RELEASE_TO('A', 'B', 0, 0)); } if constexpr(TYPE == MirrorType::Ordinary) { @@ -178,7 +192,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { EMPLACE_INTO(gc, code.to_vertical_mirror()); EMPLACE_INTO(gd, code.to_diagonal_mirror()); }); - // extend(seed, size, NO_MIRROR, RELEASE_TO('A', 'B', 'C', 'D')); } for (const auto head : RangesUnion::Heads) { @@ -195,37 +208,20 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) { } } -// KLSK_INLINE static std::tuple get_info(uint8_t type_id, uint16_t pattern_id) { -// const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; - -// const auto mirror_type = static_cast(PATTERN_DATA[flat_id] & 0b111); - -// const auto seed_val = PATTERN_DATA[flat_id] >> 23; -// auto seed = CommonCode::unsafe_create(seed_val).to_raw_code(); -// const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF; - -// return {mirror_type, seed, size}; -// } - -static RangesUnion cases_without(uint8_t type_id, const RangesUnion &data) { - Ranges ranges {}; - ranges.reserve(BASIC_RANGES_NUM[type_id]); - const auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id]; - ranges.spawn(n, n_2x1, n_1x1); - ranges.reverse(); +KLSK_INLINE static MirrorType mirror_type(uint8_t type_id, uint16_t pattern_id) { + const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; + return static_cast(PATTERN_DATA[flat_id] & 0b111); +} - RangesUnion cases; - const auto [na, nb, nc, nd] = GROUP_UNION_CASES_NUM[type_id]; - RANGE_RESERVE(0x0, na); RANGE_RESERVE(0x1, nb); RANGE_RESERVE(0x2, na); - RANGE_RESERVE(0x4, nc); RANGE_RESERVE(0x5, nd); RANGE_RESERVE(0x6, nc); - RANGE_RESERVE(0x8, nc); RANGE_RESERVE(0x9, nd); RANGE_RESERVE(0xA, nc); - RANGE_RESERVE(0xC, na); RANGE_RESERVE(0xD, nb); RANGE_RESERVE(0xE, na); +KLSK_INLINE static RawCode pattern_seed(uint8_t type_id, uint16_t pattern_id) { + const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; + const auto seed = PATTERN_DATA[flat_id] >> 23; + return CommonCode::unsafe_create(seed).to_raw_code(); +} - RANGE_DERIVE_WITHOUT(0x0); RANGE_DERIVE_WITHOUT(0x1); RANGE_DERIVE_WITHOUT(0x2); - RANGE_DERIVE_WITHOUT(0x4); RANGE_DERIVE_WITHOUT(0x5); RANGE_DERIVE_WITHOUT(0x6); - RANGE_DERIVE_WITHOUT(0x8); RANGE_DERIVE_WITHOUT(0x9); RANGE_DERIVE_WITHOUT(0xA); - RANGE_DERIVE_WITHOUT(0xC); RANGE_DERIVE_WITHOUT(0xD); RANGE_DERIVE_WITHOUT(0xE); - return cases; +KLSK_INLINE static size_t group_size(uint8_t type_id, uint16_t pattern_id) { + const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; + return (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF; } static void build_ru_arr(uint8_t type_id, RangesUnion *output) { @@ -235,36 +231,25 @@ static void build_ru_arr(uint8_t type_id, RangesUnion *output) { return; } - const bool ff_mode = (PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0; - - // uint32_t pattern_id_begin = 0; - // if ((PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0) { // first pattern type `Full` - // pattern_id_begin = 1; - // } + const bool ff_mode = mirror_type(type_id, 0) == MirrorType::Full; + // TODO: black-list filter for some type_ids for (uint16_t pattern_id = ff_mode ? 1 : 0; pattern_id < gu.pattern_num(); ++pattern_id) { - - const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id; - const auto mirror_type = static_cast(PATTERN_DATA[flat_id] & 0b111); - - const auto seed_val = PATTERN_DATA[flat_id] >> 23; - auto seed = CommonCode::unsafe_create(seed_val).to_raw_code(); - const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF; - - // const auto [mirror_type, seed, size] = get_info(type_id, pattern_id); - - const auto kk = output + pattern_id * 4; - - if (mirror_type == MirrorType::Full) { - spawn_pattern(seed, size, kk); - } else if (mirror_type == MirrorType::Horizontal) { - spawn_pattern(seed, size, kk); - } else if (mirror_type == MirrorType::Centro) { - spawn_pattern(seed, size, kk); - } else if (mirror_type == MirrorType::Vertical) { - spawn_pattern(seed, size, kk); - } else if (mirror_type == MirrorType::Ordinary) { - spawn_pattern(seed, size, kk); + const auto size = group_size(type_id, pattern_id); + const auto type = mirror_type(type_id, pattern_id); + const auto seed = pattern_seed(type_id, pattern_id); + + auto *target = output + pattern_id * 4; + if (type == MirrorType::Full) { + spawn_pattern(seed, size, target); + } else if (type == MirrorType::Horizontal) { + spawn_pattern(seed, size, target); + } else if (type == MirrorType::Centro) { + spawn_pattern(seed, size, target); + } else if (type == MirrorType::Vertical) { + spawn_pattern(seed, size, target); + } else if (type == MirrorType::Ordinary) { + spawn_pattern(seed, size, target); } } @@ -272,14 +257,12 @@ static void build_ru_arr(uint8_t type_id, RangesUnion *output) { RangesUnion others; // TODO: try to reserve - for (auto *ptr = output + 4; ptr < output + gu.pattern_num() * 4; ++ptr) { - others += *ptr; + for (auto *group = output + 4; group < output + gu.pattern_num() * 4; ++group) { + others += *group; } - for (const auto head : RangesUnion::Heads) { std::stable_sort(others.ranges(head).begin(), others.ranges(head).end()); } - // *output = gu.cases_without(others); *output = cases_without(type_id, others); } } @@ -314,24 +297,52 @@ static std::vector build_tmp_data() { data.resize(ALL_CASES_NUM_); ShortCode::speed_up(true); - for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { - auto group_union = GroupUnion::unsafe_create(type_id); - for (auto group : group_union.groups()) { - uint32_t pattern_id = group.pattern_id(); - auto toward_id = (uint32_t)group.toward(); - - // TODO: batch mirror base on pattern - auto codes = group.cases().codes(); - for (uint32_t case_id = 0; case_id < codes.size(); ++case_id) { - auto short_code = codes[case_id].to_short_code(); - data[short_code.unwrap()] = case_info_t { - .pattern_id = pattern_id, - .toward_id = toward_id, - .case_id = case_id, - }; + for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + const auto gu = GroupUnion::unsafe_create(type_id); + for (uint16_t pattern_id = 0; pattern_id < gu.pattern_num(); ++pattern_id) { + + uint16_t flat_id = PATTERN_OFFSET[type_id] + pattern_id; + + for (uint8_t toward_id = 0; toward_id < 4; ++toward_id) { + auto real_id = flat_id * 4 + toward_id; + auto group = (*ru_data_array)[real_id].codes(); + + for (uint32_t case_id = 0; case_id < group.size(); ++case_id) { + auto short_code = group[case_id].to_short_code(); + data[short_code.unwrap()] = case_info_t { + .pattern_id = pattern_id, + .toward_id = toward_id, + .case_id = case_id, + }; + } } } } + + + // std::vector data_verify; + // data_verify.resize(ALL_CASES_NUM_); + // for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) { + // auto group_union = GroupUnion::unsafe_create(type_id); + // for (auto group : group_union.groups()) { + // uint32_t pattern_id = group.pattern_id(); + // auto toward_id = (uint32_t)group.toward(); + // auto codes = group.cases().codes(); + // for (uint32_t case_id = 0; case_id < codes.size(); ++case_id) { + // auto short_code = codes[case_id].to_short_code(); + // data_verify[short_code.unwrap()] = case_info_t { + // .pattern_id = pattern_id, + // .toward_id = toward_id, + // .case_id = case_id, + // }; + // } + // } + // } + // if (data != data_verify) { + // std::cout << "!!! error" << std::endl; + // } + + return data; } @@ -342,14 +353,11 @@ void GroupCases::build() { // std::lock_guard guard {busy_}; // TODO: make `data` as class member - // static auto data_1 = build_ranges_unions(); - // static auto data_2 = build_tmp_data(); - // ru_data = &data_1; - // rev_data = &data_2; - // about 34.1ns static auto data_array = build_ru_array(); ru_data_array = &data_array; + static auto data_2 = build_tmp_data(); + rev_data = &data_2; // KLSK_MEM_BARRIER; // fast_ = true;