Browse Source

perf: construction of cases info list

master
Dnomd343 1 month ago
parent
commit
9448dbe0c0
  1. 2
      src/core/benchmark/group.cc
  2. 236
      src/core/group/internal/group_cases.cc

2
src/core/benchmark/group.cc

@ -443,6 +443,8 @@ static void FastObtainCode(benchmark::State &state) {
static void GroupCasesBuild(benchmark::State &state) {
klotski::codec::ShortCode::speed_up(true);
for (auto _ : state) {
GroupCases::build();
}

236
src/core/group/internal/group_cases.cc

@ -41,6 +41,27 @@ using klotski::group::GROUP_UNION_CASES_NUM;
#define RANGE_DERIVE_WITHOUT(HEAD) ranges.derive_without(HEAD, cases.ranges(HEAD), data.ranges(HEAD))
// Builds a RangesUnion for `type_id` by calling Ranges::derive_without on
// every head, passing the matching per-head ranges of `data` as the last
// argument.
//
// NOTE(review): presumably the result holds the cases of `type_id` that are
// NOT present in `data` — confirm against Ranges::derive_without.
//
// @param type_id index into BASIC_RANGES_NUM, BLOCK_NUM and GROUP_UNION_CASES_NUM
// @param data    per-head ranges handed to derive_without for each head
// @return the freshly built RangesUnion
static RangesUnion cases_without(uint8_t type_id, const RangesUnion &data) {
// Spawn the basic ranges from the block counts of this type, then apply
// Ranges::reverse() to them.
Ranges ranges {};
ranges.reserve(BASIC_RANGES_NUM[type_id]);
const auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id];
ranges.spawn(n, n_2x1, n_1x1);
ranges.reverse();
RangesUnion cases;
// Per-head sizing taken from GROUP_UNION_CASES_NUM: heads 0x0/0x2/0xC/0xE use
// `na`, 0x1/0xD use `nb`, 0x4/0x6/0x8/0xA use `nc`, and 0x5/0x9 use `nd`.
// NOTE(review): RANGE_RESERVE is defined outside this view; presumably it
// reserves capacity in cases.ranges(HEAD) — confirm.
const auto [na, nb, nc, nd] = GROUP_UNION_CASES_NUM[type_id];
RANGE_RESERVE(0x0, na); RANGE_RESERVE(0x1, nb); RANGE_RESERVE(0x2, na);
RANGE_RESERVE(0x4, nc); RANGE_RESERVE(0x5, nd); RANGE_RESERVE(0x6, nc);
RANGE_RESERVE(0x8, nc); RANGE_RESERVE(0x9, nd); RANGE_RESERVE(0xA, nc);
RANGE_RESERVE(0xC, na); RANGE_RESERVE(0xD, nb); RANGE_RESERVE(0xE, na);
// Each macro use expands to
// ranges.derive_without(HEAD, cases.ranges(HEAD), data.ranges(HEAD)).
RANGE_DERIVE_WITHOUT(0x0); RANGE_DERIVE_WITHOUT(0x1); RANGE_DERIVE_WITHOUT(0x2);
RANGE_DERIVE_WITHOUT(0x4); RANGE_DERIVE_WITHOUT(0x5); RANGE_DERIVE_WITHOUT(0x6);
RANGE_DERIVE_WITHOUT(0x8); RANGE_DERIVE_WITHOUT(0x9); RANGE_DERIVE_WITHOUT(0xA);
RANGE_DERIVE_WITHOUT(0xC); RANGE_DERIVE_WITHOUT(0xD); RANGE_DERIVE_WITHOUT(0xE);
return cases;
}
using MirrorType = Group::MirrorType;
struct case_info_t {
@ -51,6 +72,10 @@ struct case_info_t {
static_assert(sizeof(case_info_t) == 4);
// Equality for case_info_t: both operands are reinterpreted as a uint32_t via
// std::bit_cast and the resulting integers are compared.
bool operator==(const case_info_t &left, const case_info_t &right) {
    const auto left_bits = std::bit_cast<uint32_t>(left);
    const auto right_bits = std::bit_cast<uint32_t>(right);
    return left_bits == right_bits;
}
static std::array<RangesUnion, ALL_PATTERN_NUM * 4> *ru_data_array = nullptr;
static std::vector<case_info_t> *rev_data = nullptr;
@ -62,32 +87,33 @@ static std::vector<case_info_t> *rev_data = nullptr;
#define EMPLACE_INTO(RU, EXPR) \
EMPLACE_INTO_IMPL(RU, KLSK_UNIQUE(tmp), EXPR)
#define NO_MIRROR [](RawCode, auto) {}
#define CTR_M_IMPL spawn(code.to_diagonal_mirror());
#define CTR_MIRROR \
[](const RawCode code, auto spawn) { \
spawn(code.to_diagonal_mirror()); \
}
#define VRT_M_IMPL spawn(code.to_vertical_mirror());
#define HOR_MIRROR \
[](const RawCode code, auto spawn) { \
if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { \
spawn(m_hor); \
} \
#define HOR_M_IMPL \
const auto m_hor = code.to_horizontal_mirror(); \
if (m_hor != code) { \
spawn(m_hor); \
}
#define VRT_MIRROR \
[](const RawCode code, auto spawn) { \
spawn(code.to_vertical_mirror()); \
#define FULL_M_IMPL \
const auto m_vrt = code.to_vertical_mirror(); \
spawn(m_vrt); \
const auto m_hor = code.to_horizontal_mirror(); \
if (m_hor != code) { \
spawn(m_hor); \
spawn(m_vrt.to_horizontal_mirror()); \
}
#define FULL_MIRROR \
[](const RawCode code, auto spawn) { \
const auto m_vrt = code.to_vertical_mirror(); spawn(m_vrt); \
if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) { \
spawn(m_hor); spawn(m_vrt.to_horizontal_mirror()); \
} \
}
#define M_FUNC(EXPR) \
[](const RawCode code, auto spawn) { EXPR }
#define NO_MIRROR [](RawCode, auto) {}
#define CTR_MIRROR M_FUNC(CTR_M_IMPL)
#define HOR_MIRROR M_FUNC(HOR_M_IMPL)
#define VRT_MIRROR M_FUNC(VRT_M_IMPL)
#define FULL_MIRROR M_FUNC(FULL_M_IMPL)
template <typename MF, typename RF>
KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF add_mirror, RF release) {
@ -125,14 +151,6 @@ KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF ad
for (const auto code : mirrors) { release(code); }
}
// #define RELEASE_TO(A, B, C, D) \
// [&data_a, &data_b, &data_c, &data_d](const RawCode code) { \
// if constexpr(A == 'A') { EMPLACE_INTO(data_a, code); } \
// if constexpr(B == 'B') { EMPLACE_INTO(data_b, code.to_horizontal_mirror()); } \
// if constexpr(C == 'C') { EMPLACE_INTO(data_c, code.to_vertical_mirror()); } \
// if constexpr(D == 'D') { EMPLACE_INTO(data_d, code.to_diagonal_mirror()); } \
// }
template <MirrorType TYPE>
static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
RangesUnion &ga = data[0];
@ -144,7 +162,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
extend(seed, size, FULL_MIRROR, [&ga](const RawCode code) {
EMPLACE_INTO(ga, code);
});
// extend(seed, size, HOR_MIRROR, RELEASE_TO('A', 0, 0, 0));
}
if constexpr(TYPE == MirrorType::Horizontal) {
@ -152,7 +169,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
EMPLACE_INTO(ga, code);
EMPLACE_INTO(gc, code.to_vertical_mirror());
});
// extend(seed, size, HOR_MIRROR, RELEASE_TO('A', 0, 'C', 0));
}
if constexpr(TYPE == MirrorType::Centro) {
@ -160,7 +176,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
EMPLACE_INTO(ga, code);
EMPLACE_INTO(gb, code.to_horizontal_mirror());
});
// extend(seed, size, CTR_MIRROR, RELEASE_TO('A', 'B', 0, 0));
}
if constexpr(TYPE == MirrorType::Vertical) {
@ -168,7 +183,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
EMPLACE_INTO(ga, code);
EMPLACE_INTO(gb, code.to_horizontal_mirror());
});
// extend(seed, size, VRT_MIRROR, RELEASE_TO('A', 'B', 0, 0));
}
if constexpr(TYPE == MirrorType::Ordinary) {
@ -178,7 +192,6 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
EMPLACE_INTO(gc, code.to_vertical_mirror());
EMPLACE_INTO(gd, code.to_diagonal_mirror());
});
// extend(seed, size, NO_MIRROR, RELEASE_TO('A', 'B', 'C', 'D'));
}
for (const auto head : RangesUnion::Heads) {
@ -195,37 +208,20 @@ static void spawn_pattern(RawCode seed, size_t size, RangesUnion *data) {
}
}
// KLSK_INLINE static std::tuple<MirrorType, RawCode, size_t> get_info(uint8_t type_id, uint16_t pattern_id) {
// const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
// const auto mirror_type = static_cast<MirrorType>(PATTERN_DATA[flat_id] & 0b111);
// const auto seed_val = PATTERN_DATA[flat_id] >> 23;
// auto seed = CommonCode::unsafe_create(seed_val).to_raw_code();
// const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF;
// return {mirror_type, seed, size};
// }
static RangesUnion cases_without(uint8_t type_id, const RangesUnion &data) {
Ranges ranges {};
ranges.reserve(BASIC_RANGES_NUM[type_id]);
const auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id];
ranges.spawn(n, n_2x1, n_1x1);
ranges.reverse();
// Reads the mirror type of a pattern from its PATTERN_DATA entry.
//
// The entry is located at PATTERN_OFFSET[type_id] + pattern_id; its lowest
// three bits are converted to a MirrorType.
KLSK_INLINE static MirrorType mirror_type(uint8_t type_id, uint16_t pattern_id) {
    const auto packed = PATTERN_DATA[PATTERN_OFFSET[type_id] + pattern_id];
    return static_cast<MirrorType>(packed & 0b111);
}
RangesUnion cases;
const auto [na, nb, nc, nd] = GROUP_UNION_CASES_NUM[type_id];
RANGE_RESERVE(0x0, na); RANGE_RESERVE(0x1, nb); RANGE_RESERVE(0x2, na);
RANGE_RESERVE(0x4, nc); RANGE_RESERVE(0x5, nd); RANGE_RESERVE(0x6, nc);
RANGE_RESERVE(0x8, nc); RANGE_RESERVE(0x9, nd); RANGE_RESERVE(0xA, nc);
RANGE_RESERVE(0xC, na); RANGE_RESERVE(0xD, nb); RANGE_RESERVE(0xE, na);
// Reads the seed of a pattern from its PATTERN_DATA entry.
//
// The entry is located at PATTERN_OFFSET[type_id] + pattern_id; the bits from
// position 23 upward are passed to CommonCode::unsafe_create and the result is
// converted to a RawCode.
KLSK_INLINE static RawCode pattern_seed(uint8_t type_id, uint16_t pattern_id) {
    const auto packed = PATTERN_DATA[PATTERN_OFFSET[type_id] + pattern_id];
    return CommonCode::unsafe_create(packed >> 23).to_raw_code();
}
RANGE_DERIVE_WITHOUT(0x0); RANGE_DERIVE_WITHOUT(0x1); RANGE_DERIVE_WITHOUT(0x2);
RANGE_DERIVE_WITHOUT(0x4); RANGE_DERIVE_WITHOUT(0x5); RANGE_DERIVE_WITHOUT(0x6);
RANGE_DERIVE_WITHOUT(0x8); RANGE_DERIVE_WITHOUT(0x9); RANGE_DERIVE_WITHOUT(0xA);
RANGE_DERIVE_WITHOUT(0xC); RANGE_DERIVE_WITHOUT(0xD); RANGE_DERIVE_WITHOUT(0xE);
return cases;
// Reads the size field of a pattern from its PATTERN_DATA entry.
//
// The entry is located at PATTERN_OFFSET[type_id] + pattern_id; the field is
// the 20 bits that follow the lowest three bits of the entry.
KLSK_INLINE static size_t group_size(uint8_t type_id, uint16_t pattern_id) {
    const auto packed = PATTERN_DATA[PATTERN_OFFSET[type_id] + pattern_id];
    const auto shifted = packed >> 3;  // drop the three low bits
    return shifted & 0xFFFFF;
}
static void build_ru_arr(uint8_t type_id, RangesUnion *output) {
@ -235,36 +231,25 @@ static void build_ru_arr(uint8_t type_id, RangesUnion *output) {
return;
}
const bool ff_mode = (PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0;
// uint32_t pattern_id_begin = 0;
// if ((PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0) { // first pattern type `Full`
// pattern_id_begin = 1;
// }
const bool ff_mode = mirror_type(type_id, 0) == MirrorType::Full;
// TODO: black-list filter for some type_ids
for (uint16_t pattern_id = ff_mode ? 1 : 0; pattern_id < gu.pattern_num(); ++pattern_id) {
const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
const auto mirror_type = static_cast<MirrorType>(PATTERN_DATA[flat_id] & 0b111);
const auto seed_val = PATTERN_DATA[flat_id] >> 23;
auto seed = CommonCode::unsafe_create(seed_val).to_raw_code();
const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF;
// const auto [mirror_type, seed, size] = get_info(type_id, pattern_id);
const auto kk = output + pattern_id * 4;
if (mirror_type == MirrorType::Full) {
spawn_pattern<MirrorType::Full>(seed, size, kk);
} else if (mirror_type == MirrorType::Horizontal) {
spawn_pattern<MirrorType::Horizontal>(seed, size, kk);
} else if (mirror_type == MirrorType::Centro) {
spawn_pattern<MirrorType::Centro>(seed, size, kk);
} else if (mirror_type == MirrorType::Vertical) {
spawn_pattern<MirrorType::Vertical>(seed, size, kk);
} else if (mirror_type == MirrorType::Ordinary) {
spawn_pattern<MirrorType::Ordinary>(seed, size, kk);
const auto size = group_size(type_id, pattern_id);
const auto type = mirror_type(type_id, pattern_id);
const auto seed = pattern_seed(type_id, pattern_id);
auto *target = output + pattern_id * 4;
if (type == MirrorType::Full) {
spawn_pattern<MirrorType::Full>(seed, size, target);
} else if (type == MirrorType::Horizontal) {
spawn_pattern<MirrorType::Horizontal>(seed, size, target);
} else if (type == MirrorType::Centro) {
spawn_pattern<MirrorType::Centro>(seed, size, target);
} else if (type == MirrorType::Vertical) {
spawn_pattern<MirrorType::Vertical>(seed, size, target);
} else if (type == MirrorType::Ordinary) {
spawn_pattern<MirrorType::Ordinary>(seed, size, target);
}
}
@ -272,14 +257,12 @@ static void build_ru_arr(uint8_t type_id, RangesUnion *output) {
RangesUnion others;
// TODO: try to reserve
for (auto *ptr = output + 4; ptr < output + gu.pattern_num() * 4; ++ptr) {
others += *ptr;
for (auto *group = output + 4; group < output + gu.pattern_num() * 4; ++group) {
others += *group;
}
for (const auto head : RangesUnion::Heads) {
std::stable_sort(others.ranges(head).begin(), others.ranges(head).end());
}
// *output = gu.cases_without(others);
*output = cases_without(type_id, others);
}
}
@ -314,24 +297,52 @@ static std::vector<case_info_t> build_tmp_data() {
data.resize(ALL_CASES_NUM_);
ShortCode::speed_up(true);
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (auto group : group_union.groups()) {
uint32_t pattern_id = group.pattern_id();
auto toward_id = (uint32_t)group.toward();
// TODO: batch mirror base on pattern
auto codes = group.cases().codes();
for (uint32_t case_id = 0; case_id < codes.size(); ++case_id) {
auto short_code = codes[case_id].to_short_code();
data[short_code.unwrap()] = case_info_t {
.pattern_id = pattern_id,
.toward_id = toward_id,
.case_id = case_id,
};
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
const auto gu = GroupUnion::unsafe_create(type_id);
for (uint16_t pattern_id = 0; pattern_id < gu.pattern_num(); ++pattern_id) {
uint16_t flat_id = PATTERN_OFFSET[type_id] + pattern_id;
for (uint8_t toward_id = 0; toward_id < 4; ++toward_id) {
auto real_id = flat_id * 4 + toward_id;
auto group = (*ru_data_array)[real_id].codes();
for (uint32_t case_id = 0; case_id < group.size(); ++case_id) {
auto short_code = group[case_id].to_short_code();
data[short_code.unwrap()] = case_info_t {
.pattern_id = pattern_id,
.toward_id = toward_id,
.case_id = case_id,
};
}
}
}
}
// std::vector<case_info_t> data_verify;
// data_verify.resize(ALL_CASES_NUM_);
// for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
// auto group_union = GroupUnion::unsafe_create(type_id);
// for (auto group : group_union.groups()) {
// uint32_t pattern_id = group.pattern_id();
// auto toward_id = (uint32_t)group.toward();
// auto codes = group.cases().codes();
// for (uint32_t case_id = 0; case_id < codes.size(); ++case_id) {
// auto short_code = codes[case_id].to_short_code();
// data_verify[short_code.unwrap()] = case_info_t {
// .pattern_id = pattern_id,
// .toward_id = toward_id,
// .case_id = case_id,
// };
// }
// }
// }
// if (data != data_verify) {
// std::cout << "!!! error" << std::endl;
// }
return data;
}
@ -342,14 +353,11 @@ void GroupCases::build() {
// std::lock_guard guard {busy_};
// TODO: make `data` as class member
// static auto data_1 = build_ranges_unions();
// static auto data_2 = build_tmp_data();
// ru_data = &data_1;
// rev_data = &data_2;
// about 34.1ns
static auto data_array = build_ru_array();
ru_data_array = &data_array;
static auto data_2 = build_tmp_data();
rev_data = &data_2;
// KLSK_MEM_BARRIER;
// fast_ = true;

Loading…
Cancel
Save