Browse Source

perf: batch build of group cases

master
Dnomd343 1 month ago
parent
commit
73578fb04c
  1. 12
      src/core/benchmark/group.cc
  2. 304
      src/core/group/internal/group_cases.cc
  3. 2
      src/core/group/internal/group_union.inl
  4. 31
      src/core/main.cc

12
src/core/benchmark/group.cc

@ -441,6 +441,14 @@ static void FastObtainCode(benchmark::State &state) {
}
}
// Benchmark body: invokes `GroupCases::build()` once per measured iteration.
// NOTE(review): if `build()` keeps its result in a function-local `static`
// (as `static auto data_array = build_ru_array();` suggests), every iteration
// after the first does no real work -- confirm this measures what is intended.
static void GroupCasesBuild(benchmark::State &state) {
    for ([[maybe_unused]] auto iteration : state) {
        GroupCases::build();
    }
}
// BENCHMARK(CommonCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
// BENCHMARK(RawCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
@ -461,7 +469,9 @@ static void FastObtainCode(benchmark::State &state) {
// BENCHMARK(GroupFromRawCode)->Unit(benchmark::kMillisecond);
BENCHMARK(FastObtainCode);
// BENCHMARK(FastObtainCode);
BENCHMARK(GroupCasesBuild)->Unit(benchmark::kMillisecond);
// BENCHMARK(IsVerticalMirror);
// BENCHMARK(IsHorizontalMirror);

304
src/core/group/internal/group_cases.cc

@ -38,153 +38,247 @@ struct case_info_t {
static_assert(sizeof(case_info_t) == 4);
// struct ru_array_t {
// RangesUnion arr[4];
// };
static std::array<RangesUnion, ALL_PATTERN_NUM * 4> *ru_data_array = nullptr;
// TODO: benchmark of `phmap<Group>` structure
// static std::vector<std::vector<RangesUnion>> *ru_data = nullptr; // group_offset + toward
static std::vector<case_info_t> *rev_data = nullptr;
static std::vector<RangesUnion> *ru_data_flat = nullptr;
// Expands to a lambda that captures `RU` by reference. For each `RawCode` it
// takes the unwrapped common-code value, uses `code >> 32` to select the range
// list inside `RU`, and appends the low 32 bits to that list.
#define RELEASE_INTO(RU) \
[&RU](const RawCode raw_code) { \
const auto code = raw_code.to_common_code().unwrap(); \
RU.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code)); \
}
static phmap::flat_hash_map<Group, RangesUnion> *ru_data_map = nullptr;
// Mirror callback that ignores both arguments and never spawns a mirrored case;
// passed to `extend` when no mirrors should be recorded.
#define NO_MIRROR \
[](RawCode, auto) {}
// static std::vector<std::array<RangesUnion, 4>> *ru_data_array = nullptr;
// static std::vector<ru_array_t> *ru_data_array = nullptr;
// static std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> *ru_data_array = nullptr;
static std::array<RangesUnion, ALL_PATTERN_NUM * 4> *ru_data_array = nullptr;
// Breadth-first expansion starting from `seed`: every case reported by
// `MaskMover::next_cases` is queued once, and each finished case is finally
// handed to `release`.
//   seed       - starting layout; it is queued first.
//   reserve    - capacity hint for the queue, the mirror list and the hash map.
//   add_mirror - callable `(RawCode, spawn)`; it calls `spawn(mirror)` for each
//                mirrored layout that should be recorded but not searched.
//   release    - callable `(RawCode)`; invoked for all queued cases first and
//                then for all recorded mirrors.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `static ... *ru_data_vector` / `*rev_data` declarations and the
// `build_ranges_unions()` lines interleaved below look like removed-side
// lines rather than part of this function -- confirm against the real file.
template <typename MF, typename RF>
KLSK_NOINLINE static void extend(const RawCode seed, const size_t reserve, MF add_mirror, RF release) {
std::vector<RawCode> queue, mirrors;
phmap::flat_hash_map<RawCode, uint64_t> cases;
static std::vector<std::vector<RangesUnion>> *ru_data_vector = nullptr;
queue.reserve(reserve); mirrors.reserve(reserve);
cases.reserve(static_cast<size_t>(static_cast<double>(reserve) * 1.56)); // reduce load factor
static std::vector<case_info_t> *rev_data = nullptr;
// Callback for every case the mover produces from the current one.
auto mover = MaskMover([&queue, &cases, &mirrors, add_mirror](RawCode code, uint64_t hint) {
if (const auto [iter, ret] = cases.try_emplace(code, hint); !ret) {
iter->second |= hint; // update hint
return;
}
// First time this case is seen: schedule it and record its mirrors.
queue.emplace_back(code);
add_mirror(code, [&cases, &mirrors](RawCode mirror) {
cases.emplace(mirror, 0); // without hint
mirrors.emplace_back(mirror);
});
});
std::vector<std::vector<RangesUnion>> build_ranges_unions() {
std::vector<std::vector<RangesUnion>> unions;
unions.reserve(ALL_GROUP_NUM);
// The seed is registered the same way as a newly discovered case.
queue.emplace_back(seed);
cases.emplace(seed, 0);
add_mirror(seed, [&mirrors, &cases](RawCode mirror) {
cases.emplace(mirror, 0); // without hint
mirrors.emplace_back(mirror);
});
// TODO: add white list for single-group unions
// `offset` is the read cursor; the queue keeps growing while it is consumed.
size_t offset = 0;
while (offset != queue.size()) {
const auto curr = queue[offset++];
mover.next_cases(curr, cases.find(curr)->second);
}
for (const auto code : queue) { release(code); }
for (const auto code : mirrors) { release(code); }
}
// TODO: helper with mirror
// Builds the cases of a `Full` mirror pattern into `a`.
//   seed - starting layout passed to `extend`.
//   size - capacity hint passed to `extend`.
//   a    - receives every released case; each head's range list is sorted
//          ascending before returning.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `type_id` / `pattern_id` loops, the `tmp` / `unions` statements and
// `return unions;` interleaved below look like removed-side lines from the old
// `build_ranges_unions()` -- confirm against the real file.
KLSK_NOINLINE static void extend_full_pattern(RawCode seed, size_t size, RangesUnion &a) {
// Mirror policy: always spawn the vertical mirror; spawn the horizontal mirror
// and the combined mirror only when the horizontal mirror differs from `code`.
// NOTE(review): the vertical mirror is spawned without an inequality check --
// presumably no layout equals its own vertical mirror; confirm.
const auto mirror_func = [](const RawCode code, auto spawn) {
const auto m_vrt = code.to_vertical_mirror();
spawn(m_vrt);
if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) {
spawn(m_hor);
spawn(m_vrt.to_horizontal_mirror());
}
};
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) {
std::vector<Group> groups;
for (auto group : group_union.groups()) {
if (group.pattern_id() == pattern_id) {
groups.emplace_back(group);
}
}
// Release step: `code >> 32` selects the range list, the low 32 bits are stored.
extend(seed, size, mirror_func, [&a](const RawCode raw_code) {
const auto code = raw_code.to_common_code().unwrap();
a.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code));
});
std::vector<RangesUnion> tmp {4};
for (auto group : groups) {
tmp[(int)group.toward()] = group.cases();
}
unions.emplace_back(tmp);
}
// Cases are appended in discovery order, so each head is sorted afterwards.
for (auto head : RangesUnion::Heads) {
std::stable_sort(a.ranges(head).begin(), a.ranges(head).end());
}
return unions;
}
static std::vector<RangesUnion> build_ru_flat() {
// Builds the cases of a `Horizontal` mirror pattern into two unions at once.
//   seed - starting layout passed to `extend`.
//   size - capacity hint passed to `extend`.
//   a    - receives every released case.
//   c    - receives the vertical mirror of every released case.
// Both unions have each head's range list sorted ascending before returning.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `global_data` lines and the bare `for (uint8_t type_id ...` line interleaved
// below look like removed-side lines from the old `build_ru_flat()` -- confirm.
KLSK_NOINLINE static void extend_hor_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &c) {
// Mirror policy: spawn the horizontal mirror only when it differs from `code`.
const auto mirror_func = [](const RawCode code, auto spawn) {
if (const auto m_hor = code.to_horizontal_mirror(); m_hor != code) {
spawn(m_hor);
}
};
extend(seed, size, mirror_func, [&a, &c](const RawCode raw_code) {
const auto code = raw_code.to_common_code().unwrap();
a.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code));
// phmap::flat_hash_map<Group, RangesUnion> global_data;
// global_data.reserve(ALL_GROUP_NUM);
const auto code_ = raw_code.to_vertical_mirror().to_common_code().unwrap();
c.ranges(code_ >> 32).emplace_back(static_cast<uint32_t>(code_));
});
std::vector<RangesUnion> global_data;
global_data.resize(ALL_PATTERN_NUM * 4);
for (auto head : RangesUnion::Heads) {
std::stable_sort(a.ranges(head).begin(), a.ranges(head).end());
std::stable_sort(c.ranges(head).begin(), c.ranges(head).end());
}
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
}
const auto gu = GroupUnion::unsafe_create(type_id);
for (auto group : gu.groups()) {
// auto _ = group.cases();
// Builds the cases of a `Centro` mirror pattern into two unions at once.
//   seed - starting layout passed to `extend`.
//   size - capacity hint passed to `extend`.
//   a    - receives every released case.
//   b    - receives the horizontal mirror of every released case.
// Both unions have each head's range list sorted ascending before returning.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `global_data` / `index` lines and the extra `}` interleaved below look like
// removed-side lines from the old `build_ru_flat()` -- confirm.
KLSK_NOINLINE static void extend_cen_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b) {
// Mirror policy: always spawn the diagonal mirror.
const auto mirror_func = [](const RawCode code, auto spawn) {
spawn(code.to_diagonal_mirror());
};
extend(seed, size, mirror_func, [&a, &b](const RawCode raw_code) {
const auto code = raw_code.to_common_code().unwrap();
a.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code));
// global_data.emplace(group, group.cases());
const auto code_ = raw_code.to_horizontal_mirror().to_common_code().unwrap();
b.ranges(code_ >> 32).emplace_back(static_cast<uint32_t>(code_));
});
const auto index = (PATTERN_OFFSET[type_id] + group.pattern_id()) * 4 + (int)group.toward();
global_data[index] = group.cases();
}
for (auto head : RangesUnion::Heads) {
std::stable_sort(a.ranges(head).begin(), a.ranges(head).end());
std::stable_sort(b.ranges(head).begin(), b.ranges(head).end());
}
}
// std::println("load factor: {}", global_data.load_factor());
// Builds the cases of a `Vertical` mirror pattern into two unions at once.
//   seed - starting layout passed to `extend`.
//   size - capacity hint passed to `extend`.
//   a    - receives every released case.
//   b    - receives the horizontal mirror of every released case.
// Both unions have each head's range list sorted ascending before returning.
// NOTE(review): this listing is a rendered diff without +/- markers; the
// `return global_data;` line below looks like a removed-side line from the old
// `build_ru_flat()` -- confirm against the real file.
KLSK_NOINLINE static void extend_ver_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b) {
// Mirror policy: always spawn the vertical mirror.
const auto mirror_func = [](const RawCode code, auto spawn) {
spawn(code.to_vertical_mirror());
};
extend(seed, size, mirror_func, [&a, &b](const RawCode raw_code) {
const auto code = raw_code.to_common_code().unwrap();
a.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code));
const auto code_ = raw_code.to_horizontal_mirror().to_common_code().unwrap();
b.ranges(code_ >> 32).emplace_back(static_cast<uint32_t>(code_));
});
return global_data;
for (auto head : RangesUnion::Heads) {
std::stable_sort(a.ranges(head).begin(), a.ranges(head).end());
std::stable_sort(b.ranges(head).begin(), b.ranges(head).end());
}
}
static phmap::flat_hash_map<Group, RangesUnion> build_ru_map() {
// Builds the cases of an `Ordinary` pattern into four unions at once. No
// mirrors are spawned during the search (`NO_MIRROR`).
//   seed - starting layout passed to `extend`.
//   size - capacity hint passed to `extend`.
//   a    - receives every released case.
//   b    - receives the horizontal mirror of every released case.
//   c    - receives the vertical mirror of every released case.
//   d    - receives the diagonal mirror of every released case.
// All four unions have each head's range list sorted ascending before returning.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `global_data` statements, the `type_id` loop and `return global_data;`
// interleaved below look like removed-side lines from the old
// `build_ru_map()` -- confirm against the real file.
KLSK_NOINLINE static void extend_ord_pattern(RawCode seed, size_t size, RangesUnion &a, RangesUnion &b, RangesUnion &c, RangesUnion &d) {
extend(seed, size, NO_MIRROR, [&a, &b, &c, &d](const RawCode raw_code) {
const auto code = raw_code.to_common_code().unwrap();
a.ranges(code >> 32).emplace_back(static_cast<uint32_t>(code));
phmap::flat_hash_map<Group, RangesUnion> global_data;
global_data.reserve(ALL_GROUP_NUM);
const auto code_1 = raw_code.to_horizontal_mirror().to_common_code().unwrap();
b.ranges(code_1 >> 32).emplace_back(static_cast<uint32_t>(code_1));
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
const auto code_2 = raw_code.to_vertical_mirror().to_common_code().unwrap();
c.ranges(code_2 >> 32).emplace_back(static_cast<uint32_t>(code_2));
const auto gu = GroupUnion::unsafe_create(type_id);
for (auto group : gu.groups()) {
global_data.emplace(group, group.cases());
}
}
const auto code_3 = raw_code.to_diagonal_mirror().to_common_code().unwrap();
d.ranges(code_3 >> 32).emplace_back(static_cast<uint32_t>(code_3));
});
return global_data;
for (auto head : RangesUnion::Heads) {
std::stable_sort(a.ranges(head).begin(), a.ranges(head).end());
std::stable_sort(b.ranges(head).begin(), b.ranges(head).end());
std::stable_sort(c.ranges(head).begin(), c.ranges(head).end());
std::stable_sort(d.ranges(head).begin(), d.ranges(head).end());
}
}
// std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> build_ru_array() {
// Builds the flat table of case unions, four slots per pattern; a pattern's
// slots start at `(PATTERN_OFFSET[type_id] + pattern_id) * 4`.
// For every `type_id`:
//   - a union with a single group is stored whole in its first slot;
//   - otherwise each pattern is expanded from the seed packed in
//     `PATTERN_DATA` by the `extend_*_pattern` helper matching its mirror type;
//   - when the first pattern is skipped (`pattern_id_begin == 1`), its slot is
//     filled with `group_union.cases_without(others)`, where `others` is the
//     sum of all remaining slots of that union.
// Returns the filled array by value.
// NOTE(review): this listing is a rendered diff without +/- markers. The
// `groups` collection loops, the `group.cases()` assignments, the first
// `return data; }` and the `build_ru_vector()` lines interleaved below look
// like removed-side lines -- confirm against the real file.
std::array<RangesUnion, ALL_PATTERN_NUM * 4> build_ru_array() {
// std::vector<std::array<RangesUnion, 4>> data;
// std::vector<ru_array_t> data;
// std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> data;
std::array<RangesUnion, ALL_PATTERN_NUM * 4> data;
// data.reserve(ALL_PATTERN_NUM);
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) {
std::vector<Group> groups;
for (auto group : group_union.groups()) {
if (group.pattern_id() == pattern_id) {
groups.emplace_back(group);
}
}
auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
for (auto group : groups) {
data[flat_id * 4 + (int)group.toward()] = group.cases();
}
if (group_union.group_num() == 1) { // only single group
// std::println("type_id = {}", type_id);
data[PATTERN_OFFSET[type_id] * 4] = group_union.cases();
continue;
}
// ru_array_t tmp {};
// The low 3 bits of a PATTERN_DATA entry are compared against 0 here and
// cast to `Group::MirrorType` further down.
uint32_t pattern_id_begin = 0;
if ((PATTERN_DATA[PATTERN_OFFSET[type_id]] & 0b111) == 0) { // first pattern is `x`
pattern_id_begin = 1;
}
for (uint32_t pattern_id = pattern_id_begin; pattern_id < group_union.pattern_num(); ++pattern_id) {
// std::vector<Group> groups;
// for (auto group : group_union.groups()) {
// if (group.pattern_id() == pattern_id) {
// groups.emplace_back(group);
// }
// }
//
// auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
// for (auto group : groups) {
// tmp.arr[(int)group.toward()] = group.cases();
// data[flat_id * 4 + (int)group.toward()] = Group_cases(group);
// }
// data.emplace_back(tmp);
// Unpack the PATTERN_DATA entry: bits 0-2 mirror type, bits 3-22 size hint,
// bits 23 and up the seed's common-code value.
const auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
const auto mirror_type = static_cast<Group::MirrorType>(PATTERN_DATA[flat_id] & 0b111);
const auto seed_val = PATTERN_DATA[flat_id] >> 23;
auto seed = CommonCode::unsafe_create(seed_val).to_raw_code();
const auto size = (PATTERN_DATA[flat_id] >> 3) & 0xFFFFF;
// Slot usage per mirror type: Full -> 0; Horizontal -> 0 and 2;
// Centro / Vertical -> 0 and 1; Ordinary -> 0, 1, 2 and 3.
if (mirror_type == Group::MirrorType::Full) {
extend_full_pattern(seed, size, data[flat_id * 4]);
} else if (mirror_type == Group::MirrorType::Horizontal) {
extend_hor_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 2]);
} else if (mirror_type == Group::MirrorType::Centro) {
extend_cen_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1]);
} else if (mirror_type == Group::MirrorType::Vertical) {
extend_ver_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1]);
} else if (mirror_type == Group::MirrorType::Ordinary) {
extend_ord_pattern(seed, size, data[flat_id * 4], data[flat_id * 4 + 1], data[flat_id * 4 + 2], data[flat_id * 4 + 3]);
}
}
}
return data;
}
std::vector<std::vector<RangesUnion>> build_ru_vector() {
std::vector<std::vector<RangesUnion>> data;
data.reserve(ALL_GROUP_NUM);
// The skipped first pattern is derived as "whole union minus every other slot".
if (pattern_id_begin == 1) { // first pattern is `x`
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) {
std::vector<Group> groups;
for (auto group : group_union.groups()) {
if (group.pattern_id() == pattern_id) {
groups.emplace_back(group);
}
RangesUnion others;
size_t index_begin = (PATTERN_OFFSET[type_id] + 1) * 4;
size_t index_end = (PATTERN_OFFSET[type_id] + group_union.pattern_num()) * 4;
for (size_t index = index_begin; index < index_end; ++index) {
others += data[index];
}
std::vector<RangesUnion> tmp {};
tmp.resize(4);
for (auto group : groups) {
tmp[(int)group.toward()] = group.cases();
// `others` is sorted per head before it is passed to `cases_without`.
for (auto head : RangesUnion::Heads) {
std::stable_sort(others.ranges(head).begin(), others.ranges(head).end());
}
data.emplace_back(tmp);
data[PATTERN_OFFSET[type_id] * 4] = group_union.cases_without(others);
}
}
// verify
// for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
// const auto group_union = GroupUnion::unsafe_create(type_id);
// for (const auto group : group_union.groups()) {
// const auto flat_id = PATTERN_OFFSET[type_id] + group.pattern_id();
// const auto index = flat_id * 4 + (int)group.toward();
//
// const auto a = data[index].codes();
// const auto b = group.cases().codes();
// if (a != b) {
// std::cout << "!!! error: " << group << " | " << a.size() << " vs " << b.size() << std::endl;
// }
// }
// }
return data;
}
@ -226,22 +320,10 @@ void GroupCases::build() {
// ru_data = &data_1;
// rev_data = &data_2;
// about 35.7ns
// static auto data_flat = build_ru_flat();
// ru_data_flat = &data_flat;
// about 108ns
// static auto data_map = build_ru_map();
// ru_data_map = &data_map;
// about 34.1ns
static auto data_array = build_ru_array();
ru_data_array = &data_array;
// about 34.4ns
// static auto data_vector = build_ru_vector();
// ru_data_vector = &data_vector;
// KLSK_MEM_BARRIER;
// fast_ = true;
}

2
src/core/group/internal/group_union.inl

@ -18,7 +18,7 @@ constexpr uint32_t GroupUnion::group_num() const {
return GROUP_NUM[type_id_];
}
// Returns the `PATTERN_NUM` table entry for this union's `type_id_`.
// NOTE(review): the signature appears twice because this listing is a rendered
// diff without +/- markers (old line, then new line with the TODO) -- confirm.
constexpr uint32_t GroupUnion::pattern_num() const {
constexpr uint32_t GroupUnion::pattern_num() const { // TODO: why not using `uint16_t`
return PATTERN_NUM[type_id_];
}

31
src/core/main.cc

@ -31,6 +31,7 @@ using klotski::codec::CommonCode;
using klotski::group::GroupUnion;
using klotski::group::Group;
using klotski::group::CaseInfo;
using klotski::group::GroupCases;
using klotski::group::GroupUnion;
@ -129,18 +130,24 @@ int main() {
// perf-b -> ~1311ms
// perf-c -> ~1272ms
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
for (auto group: GroupUnion::unsafe_create(type_id).groups()) {
if (group.mirror_type() == Group::MirrorType::Full) {
// if (group.mirror_type() == Group::MirrorType::Horizontal) {
// if (group.mirror_type() == Group::MirrorType::Centro) {
// if (group.mirror_type() == Group::MirrorType::Vertical) {
// if (group.mirror_type() == Group::MirrorType::Ordinary) {
// std::println("{} ({})", group.to_string(), group.size());
volatile auto kk = group.cases();
}
}
}
// for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
// for (auto group: GroupUnion::unsafe_create(type_id).groups()) {
// // if (group.mirror_type() == Group::MirrorType::Full) {
// // if (group.mirror_type() == Group::MirrorType::Horizontal) {
// // if (group.mirror_type() == Group::MirrorType::Centro) {
// // if (group.mirror_type() == Group::MirrorType::Vertical) {
// // if (group.mirror_type() == Group::MirrorType::Ordinary) {
// // std::println("{} ({})", group.to_string(), group.size());
// volatile auto kk = group.cases();
// // }
// }
// }
GroupCases::build();
constexpr auto group = Group::unsafe_create(169, 0, Group::Toward::C);
constexpr auto info = CaseInfo::unsafe_create(group, 7472);
std::cout << info << ": " << GroupCases::obtain_code(info) << std::endl;
// constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A);
// std::cout << group.to_string() << std::endl;

Loading…
Cancel
Save