Browse Source

perf: pre-build of all group cases

master
Dnomd343 1 month ago
parent
commit
ddc16c3d45
  1. 80
      src/core/benchmark/group.cc
  2. 171
      src/core/group/internal/group_cases.cc

80
src/core/benchmark/group.cc

@ -134,24 +134,24 @@ static void GroupExtend(benchmark::State &state) {
for (auto _ : state) {
// for (int type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
// for (auto group : GroupUnion::unsafe_create(type_id).groups()) {
// // if (group.mirror_type() == Group::MirrorType::Full) {
// // if (group.mirror_type() == Group::MirrorType::Horizontal) {
// // if (group.mirror_type() == Group::MirrorType::Centro) {
// // if (group.mirror_type() == Group::MirrorType::Vertical) {
// if (group.mirror_type() == Group::MirrorType::Ordinary) {
// // std::println("{} ({})", group.to_string(), group.size());
// volatile auto kk = group.cases();
// }
// }
// }
for (int type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
for (auto group : GroupUnion::unsafe_create(type_id).groups()) {
if (group.mirror_type() == Group::MirrorType::Full) {
// if (group.mirror_type() == Group::MirrorType::Horizontal) {
// if (group.mirror_type() == Group::MirrorType::Centro) {
// if (group.mirror_type() == Group::MirrorType::Vertical) {
// if (group.mirror_type() == Group::MirrorType::Ordinary) {
// std::println("{} ({})", group.to_string(), group.size());
volatile auto kk = group.cases();
}
}
}
constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A);
// constexpr auto group = Group::unsafe_create(89, 0, Group::Toward::A);
// constexpr auto group = Group::unsafe_create(51, 0, Group::Toward::A);
// constexpr auto group = Group::unsafe_create(98, 0, Group::Toward::A);
volatile auto kk = group.cases();
// volatile auto kk = group.cases();
// for (auto group : groups) {
// volatile auto tmp = group.cases();
@ -370,7 +370,7 @@ static void ToHorizontalMirror(benchmark::State &state) {
}
static void FastObtainCode(benchmark::State &state) {
// GroupCases::build();
GroupCases::build();
// std::vector<CaseInfo> infos;
// for (auto code : common_code_samples(64)) {
@ -379,14 +379,14 @@ static void FastObtainCode(benchmark::State &state) {
// infos.emplace_back(CaseInfo::unsafe_create())
const auto group = Group::unsafe_create(169, 0, Group::Toward::C);
const klotski::cases::RangesUnion data = group.cases();
std::array<size_t, 16> sizes {};
size_t offset = 0;
for (int i = 0; i < 16; ++i) {
sizes[i] = offset;
offset += data.ranges(i).size();
}
// const klotski::cases::RangesUnion data = group.cases();
//
// std::array<size_t, 16> sizes {};
// size_t offset = 0;
// for (int i = 0; i < 16; ++i) {
// sizes[i] = offset;
// offset += data.ranges(i).size();
// }
std::vector infos {
CaseInfo::unsafe_create(group, 2631),
@ -415,24 +415,27 @@ static void FastObtainCode(benchmark::State &state) {
for (auto info : infos) {
/// about 35ns
auto &cases = data;
uint64_t head = 0;
auto case_id = info.case_id();
for (;;) {
if (case_id >= cases.ranges(head).size()) {
case_id -= cases.ranges(head).size();
++head;
} else {
break;
}
}
auto range = cases.ranges(head)[case_id];
volatile auto kk = CommonCode::unsafe_create(head << 32 | range);
// auto &cases = data;
// uint64_t head = 0;
// auto case_id = info.case_id();
// for (;;) {
// if (case_id >= cases.ranges(head).size()) {
// case_id -= cases.ranges(head).size();
// ++head;
// } else {
// break;
// }
// }
// auto range = cases.ranges(head)[case_id];
// volatile auto kk = CommonCode::unsafe_create(head << 32 | range);
/// about 117ns
// uint64_t head = std::upper_bound(sizes.begin(), sizes.end(), info.case_id()) - sizes.begin() - 1;
// uint32_t range = data[head][info.case_id() - sizes[head]];
// volatile auto kk = CommonCode::unsafe_create(head << 32 | range);
volatile auto kk = GroupCases::fast_obtain_code(info);
}
}
@ -441,7 +444,8 @@ static void FastObtainCode(benchmark::State &state) {
// BENCHMARK(CommonCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
// BENCHMARK(RawCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond);
// BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond);
// BENCHMARK(GroupExtend)->Unit(benchmark::kMicrosecond);
// BENCHMARK(FilterFromAllCases)->Unit(benchmark::kMillisecond);
@ -457,7 +461,7 @@ BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond);
// BENCHMARK(GroupFromRawCode)->Unit(benchmark::kMillisecond);
// BENCHMARK(FastObtainCode);
BENCHMARK(FastObtainCode);
// BENCHMARK(IsVerticalMirror);
// BENCHMARK(IsHorizontalMirror);

171
src/core/group/internal/group_cases.cc

@ -26,6 +26,10 @@ using klotski::cases::ALL_CASES_NUM_;
using klotski::group::GROUP_DATA;
using klotski::group::PATTERN_DATA;
using klotski::group::PATTERN_OFFSET;
using klotski::group::ALL_PATTERN_NUM;
struct case_info_t {
uint32_t pattern_id : 10;
uint32_t toward_id : 2;
@ -34,8 +38,24 @@ struct case_info_t {
static_assert(sizeof(case_info_t) == 4);
// struct ru_array_t {
// RangesUnion arr[4];
// };
// TODO: benchmark of `phmap<Group>` structure
static std::vector<std::vector<RangesUnion>> *ru_data = nullptr; // group_offset + toward
// static std::vector<std::vector<RangesUnion>> *ru_data = nullptr; // group_offset + toward
static std::vector<RangesUnion> *ru_data_flat = nullptr;
static phmap::flat_hash_map<Group, RangesUnion> *ru_data_map = nullptr;
// static std::vector<std::array<RangesUnion, 4>> *ru_data_array = nullptr;
// static std::vector<ru_array_t> *ru_data_array = nullptr;
// static std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> *ru_data_array = nullptr;
static std::array<RangesUnion, ALL_PATTERN_NUM * 4> *ru_data_array = nullptr;
static std::vector<std::vector<RangesUnion>> *ru_data_vector = nullptr;
static std::vector<case_info_t> *rev_data = nullptr;
std::vector<std::vector<RangesUnion>> build_ranges_unions() {
@ -66,6 +86,108 @@ std::vector<std::vector<RangesUnion>> build_ranges_unions() {
return unions;
}
static std::vector<RangesUnion> build_ru_flat() {
// phmap::flat_hash_map<Group, RangesUnion> global_data;
// global_data.reserve(ALL_GROUP_NUM);
std::vector<RangesUnion> global_data;
global_data.resize(ALL_PATTERN_NUM * 4);
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
const auto gu = GroupUnion::unsafe_create(type_id);
for (auto group : gu.groups()) {
// auto _ = group.cases();
// global_data.emplace(group, group.cases());
const auto index = (PATTERN_OFFSET[type_id] + group.pattern_id()) * 4 + (int)group.toward();
global_data[index] = group.cases();
}
}
// std::println("load factor: {}", global_data.load_factor());
return global_data;
}
static phmap::flat_hash_map<Group, RangesUnion> build_ru_map() {
phmap::flat_hash_map<Group, RangesUnion> global_data;
global_data.reserve(ALL_GROUP_NUM);
for (uint8_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
const auto gu = GroupUnion::unsafe_create(type_id);
for (auto group : gu.groups()) {
global_data.emplace(group, group.cases());
}
}
return global_data;
}
// std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> build_ru_array() {
std::array<RangesUnion, ALL_PATTERN_NUM * 4> build_ru_array() {
// std::vector<std::array<RangesUnion, 4>> data;
// std::vector<ru_array_t> data;
// std::array<std::array<RangesUnion, 4>, ALL_PATTERN_NUM> data;
std::array<RangesUnion, ALL_PATTERN_NUM * 4> data;
// data.reserve(ALL_PATTERN_NUM);
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) {
std::vector<Group> groups;
for (auto group : group_union.groups()) {
if (group.pattern_id() == pattern_id) {
groups.emplace_back(group);
}
}
auto flat_id = PATTERN_OFFSET[type_id] + pattern_id;
for (auto group : groups) {
data[flat_id * 4 + (int)group.toward()] = group.cases();
}
// ru_array_t tmp {};
// for (auto group : groups) {
// tmp.arr[(int)group.toward()] = group.cases();
// }
// data.emplace_back(tmp);
}
}
return data;
}
std::vector<std::vector<RangesUnion>> build_ru_vector() {
std::vector<std::vector<RangesUnion>> data;
data.reserve(ALL_GROUP_NUM);
for (uint32_t type_id = 0; type_id < TYPE_ID_LIMIT; ++type_id) {
auto group_union = GroupUnion::unsafe_create(type_id);
for (uint32_t pattern_id = 0; pattern_id < group_union.pattern_num(); ++pattern_id) {
std::vector<Group> groups;
for (auto group : group_union.groups()) {
if (group.pattern_id() == pattern_id) {
groups.emplace_back(group);
}
}
std::vector<RangesUnion> tmp {};
tmp.resize(4);
for (auto group : groups) {
tmp[(int)group.toward()] = group.cases();
}
data.emplace_back(tmp);
}
}
return data;
}
static std::vector<case_info_t> build_tmp_data() {
std::vector<case_info_t> data;
data.resize(ALL_CASES_NUM_);
@ -93,19 +215,35 @@ static std::vector<case_info_t> build_tmp_data() {
}
void GroupCases::build() {
if (fast_) {
return;
}
std::lock_guard guard {busy_};
// if (fast_) {
// return;
// }
// std::lock_guard guard {busy_};
// TODO: make `data` as class member
static auto data_1 = build_ranges_unions();
static auto data_2 = build_tmp_data();
ru_data = &data_1;
rev_data = &data_2;
// static auto data_1 = build_ranges_unions();
// static auto data_2 = build_tmp_data();
// ru_data = &data_1;
// rev_data = &data_2;
// about 35.7ns
// static auto data_flat = build_ru_flat();
// ru_data_flat = &data_flat;
// about 108ns
// static auto data_map = build_ru_map();
// ru_data_map = &data_map;
// about 34.1ns
static auto data_array = build_ru_array();
ru_data_array = &data_array;
// about 34.4ns
// static auto data_vector = build_ru_vector();
// ru_data_vector = &data_vector;
KLSK_MEM_BARRIER;
fast_ = true;
// KLSK_MEM_BARRIER;
// fast_ = true;
}
void GroupCases::build_async(Executor &&executor, Notifier &&callback) {
@ -120,9 +258,18 @@ CommonCode GroupCases::fast_obtain_code(CaseInfo info) {
auto flat_id = PATTERN_OFFSET[info.group().type_id()] + info.group().pattern_id();
auto &cases = (*ru_data)[flat_id][(int)info.group().toward()];
// auto &cases = (*ru_data)[flat_id][(int)info.group().toward()];
// TODO: make offset table for perf
// auto &cases = (*ru_data_flat)[flat_id * 4 + (int)info.group().toward()];
// auto &cases = (*ru_data_map)[info.group()];
// auto &cases = (*ru_data_array)[flat_id][(int)info.group().toward()];
auto &cases = (*ru_data_array)[flat_id * 4 + (int)info.group().toward()];
// auto &cases = (*ru_data_vector)[flat_id][(int)info.group().toward()];
uint64_t head = 0;
auto case_id = info.case_id();

Loading…
Cancel
Save