Browse Source

perf: optimize range generation

master
Dnomd343 1 month ago
parent
commit
f9a88410be
  1. 66
      src/core/benchmark/group.cc
  2. 2
      src/core/group/group.h
  3. 128
      src/core/group/internal/group.cc
  4. 33
      src/core/main.cc
  5. 40
      src/core_test/group_tmp/group_extend.cc

66
src/core/benchmark/group.cc

@ -109,9 +109,73 @@ static void GroupExtend(benchmark::State &state) {
}
// static void FilterFromAllCases(benchmark::State &state) {
//
// klotski::cases::AllCases::instance().build();
//
// for (auto _ : state) {
// for (uint64_t head = 0; head < 16; ++head) {
//
// for (const auto range : AllCases::instance().fetch()[head]) {
// uint64_t common_code = head << 32 | range;
//
// volatile auto ret = klotski::cases::common_code_to_type_id(common_code);
//
// }
//
// }
// }
//
// }
/// Build the argument tuples that the benchmark passes to `spawn_ranges`.
///
/// Each tuple is `(16 - 2 * total - count_1x1, total - count_2x1, count_2x1, count_1x1)`
/// for `total` in [0, 7], `count_2x1` in [0, total] and `count_1x1` in
/// [0, 14 - 2 * total], produced in that nesting order.
static std::vector<std::tuple<int, int, int, int>> target_nums() {
    using Nums = std::tuple<int, int, int, int>;

    std::vector<Nums> combos;
    for (int total = 0; total <= 7; ++total) {
        const int limit_1x1 = 14 - total * 2;
        for (int count_2x1 = 0; count_2x1 <= total; ++count_2x1) {
            const int remainder = total - count_2x1;
            for (int count_1x1 = 0; count_1x1 <= limit_1x1; ++count_1x1) {
                const int leading = 16 - total * 2 - count_1x1;
                combos.emplace_back(leading, remainder, count_2x1, count_1x1);
            }
        }
    }

    // The loops above yield 204 tuples; only the first 203 are kept.
    // NOTE(review): presumably the final tuple (2, 0, 7, 0) is not a usable target — confirm.
    combos.resize(203);
    return combos;
}
/// Benchmark: on every iteration, call `spawn_ranges` once for each tuple
/// returned by `target_nums()`.
static void SpawnRanges(benchmark::State &state) {
    // Built once, outside the timed loop, so tuple construction is not measured.
    const auto nums = target_nums();

    for (auto _ : state) {
        for (const auto &[n_00, n_01, n_10, n_11] : nums) {
            auto ranges = klotski::cases::spawn_ranges(n_00, n_01, n_10, n_11);
            // Keep the otherwise unused result observable so the optimizer
            // cannot drop the call being measured.
            benchmark::DoNotOptimize(ranges);
        }
    }
}
// Benchmark registration. The type-id benchmarks below are currently disabled.
// BENCHMARK(CommonCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
// BENCHMARK(RawCodeToTypeId)->Arg(8)->Arg(64)->Arg(256);
// Active benchmarks; timings are reported in milliseconds.
BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond);
// BENCHMARK(GroupExtend)->Unit(benchmark::kMillisecond);
// BENCHMARK(FilterFromAllCases)->Unit(benchmark::kMillisecond);
BENCHMARK(SpawnRanges)->Unit(benchmark::kMillisecond);
// Expands to the program entry point that runs the registered benchmarks.
BENCHMARK_MAIN();

2
src/core/group/group.h

@ -80,6 +80,8 @@ uint32_t raw_code_to_type_id(uint64_t raw_code);
// Expand a group starting from the seed `raw_code`: the seed is the first
// element of the result, followed by every other code discovered while
// repeatedly expanding the codes already collected.
std::vector<uint64_t> group_extend_from_seed(uint64_t raw_code);
// Enumerate the 32-bit ranges formed by every distinct ordering of `n_01`
// values of 0b01, `n_10` values of 0b10, `n_11` values of 0b11 and the
// remaining 0b00 values, packed two bits each from the high bits downwards.
// NOTE(review): the implementation shown in group.cc does not read `n_00`;
// it derives the 0b00 count from the other three arguments — confirm intent.
std::vector<uint32_t> spawn_ranges(int n_00, int n_01, int n_10, int n_11);
class Group;
// TODO: add constexpr

128
src/core/group/internal/group.cc

@ -2,10 +2,10 @@
#include <queue>
#include <absl/container/btree_map.h>
#include <absl/container/btree_set.h>
#include <absl/container/flat_hash_map.h>
#include <absl/container/node_hash_map.h>
#include <core/core.h>
static KLSK_INLINE uint32_t type_id(const int n, const int n_2x1, const int n_1x1) {
@ -31,78 +31,126 @@ std::vector<uint64_t> klotski::cases::group_extend_from_seed(uint64_t raw_code)
auto max_size = GroupUnion::create(raw_code_to_type_id(raw_code))->max_group_size();
// std::queue<uint64_t> cache({raw_code});
uint64_t offset = 0;
std::vector<uint64_t> results;
results.reserve(max_size);
results.emplace_back(raw_code);
// uint64_t offset = 0;
// std::vector<std::pair<uint64_t, uint64_t>> results;
// results.reserve(max_size);
// results.emplace_back(raw_code, 0);
absl::flat_hash_map<uint64_t, uint64_t> cases; // <code, mask>
// absl::node_hash_map<uint64_t, uint64_t> cases; // <code, mask>
// std::unordered_map<uint64_t, uint64_t> cases; // <code, mask>
// std::map<uint64_t, uint64_t> cases; // <code, mask>
// absl::btree_map<uint64_t, uint64_t> cases; // <code, mask>
cases.reserve(max_size);
cases.emplace(raw_code, 0); // without mask
// std::cout << max_size << std::endl;
auto core = klotski::core::Core(
[&results, &cases](auto code, auto mask) { // callback function
auto current = cases.find(code);
if (current != cases.end()) {
current->second |= mask; // update mask
// results[current->second].second |= mask; // update mask
return;
}
cases.emplace(code, mask);
// cases.emplace(code, results.size());
// cache.emplace(code);
results.emplace_back(code);
// cache.emplace(code, 0);
// results.emplace_back(code, 0);
}
);
while (offset != results.size()) {
auto tmp = results[offset];
core.next_cases(tmp, cases.find(tmp)->second);
// core.next_cases(tmp, 0);
// core.next_cases(tmp.first, tmp.second);
++offset;
}
return results;
// return {};
}
/// Enumerate every distinct arrangement of `16 - N` two-bit values and pack
/// each arrangement into a 32-bit range, aligned to the high bits.
///
/// The series consists of `n_00` values of 0b00, `n_01` values of 0b01,
/// `n_10` values of 0b10 and `n_11` values of 0b11, where
/// `n_00 = 16 - 2 * N - n_11` and `n_01 = N - n_10`.
///
/// @tparam N    combined number of 0b01 and 0b10 values, within [0, 7]
/// @param n_10  number of 0b10 values, within [0, N]
/// @param n_11  number of 0b11 values, within [0, 16 - 2 * N]
/// @return      every range in ascending order, or an empty vector when
///              `n_10` or `n_11` is out of bounds
template<int N>
static std::vector<uint32_t> demo(int n_10, int n_11) {
    static_assert(N >= 0 && N <= 7, "N must be within [0, 7]");

    constexpr auto num = 16 - N; // number of 2-bit values in the series
    constexpr auto offset = (16 - num) << 1; // offset of low bits

    // Counts outside these bounds would make the fills below write outside `series`.
    if (n_10 < 0 || n_10 > N || n_11 < 0 || n_11 > 16 - N * 2) {
        return {};
    }
    const int n_00 = 16 - N * 2 - n_11;
    const int n_01 = N - n_10;

    // `series` is value-initialized, so the leading `n_00` entries are already
    // 0b00 and only the three non-zero runs are written. The result is the
    // sorted (lexicographically smallest) arrangement.
    std::array<int, num> series {};
    auto cursor = std::fill_n(series.begin() + n_00, n_01, 0b01);
    cursor = std::fill_n(cursor, n_10, 0b10);
    std::fill_n(cursor, n_11, 0b11);

    // Number of distinct arrangements: num! / (n_00! * n_01! * n_10! * n_11!).
    // 16! fits in 64 bits, so neither the numerator nor the divisor can overflow.
    const auto factorial = [](const int k) {
        uint64_t value = 1;
        for (int i = 2; i <= k; ++i) {
            value *= static_cast<uint64_t>(i);
        }
        return value;
    };
    const uint64_t total = factorial(num)
        / (factorial(n_00) * factorial(n_01) * factorial(n_10) * factorial(n_11));

    std::vector<uint32_t> ranges;
    ranges.reserve(total); // exact size known up front, avoids regrowth

    do { // full permutation traversal
        uint32_t range = 0;
        for (const auto x : series) { // store every 2-bit
            (range <<= 2) |= x;
        }
        ranges.emplace_back(range << offset);
    } while (std::next_permutation(series.begin(), series.end()));
    return ranges;
}
/// Enumerate the ranges for the given counts of 2-bit values.
///
/// Dispatches on `n_01 + n_10` to the matching `demo<N>` instantiation so that
/// the series length is a compile-time constant.
///
/// @param n_00  number of 0b00 values; not read here, because the callee
///              derives it as `16 - 2 * (n_01 + n_10) - n_11`
/// @param n_01  number of 0b01 values
/// @param n_10  number of 0b10 values
/// @param n_11  number of 0b11 values
/// @return      the generated ranges, or an empty vector when a count is
///              negative, `n_01 + n_10` exceeds 7, or `n_11` exceeds
///              `16 - 2 * (n_01 + n_10)`
std::vector<uint32_t> klotski::cases::spawn_ranges([[maybe_unused]] int n_00, int n_01, int n_10, int n_11) {
    // Reject counts that cannot describe a series; the upper bounds are checked
    // before summing so the addition below cannot overflow.
    if (n_01 < 0 || n_01 > 7 || n_10 < 0 || n_10 > 7 || n_11 < 0) {
        return {};
    }
    const auto n = n_01 + n_10;
    if (n_11 > 16 - n * 2) {
        return {};
    }

    switch (n) {
        case 0: return demo<0>(n_10, n_11);
        case 1: return demo<1>(n_10, n_11);
        case 2: return demo<2>(n_10, n_11);
        case 3: return demo<3>(n_10, n_11);
        case 4: return demo<4>(n_10, n_11);
        case 5: return demo<5>(n_10, n_11);
        case 6: return demo<6>(n_10, n_11);
        case 7: return demo<7>(n_10, n_11);
        default: return {};
    }
}

33
src/core/main.cc

@ -1,3 +1,4 @@
#include <algorithm>
#include <thread>
#include <iostream>
#include <format>
@ -25,10 +26,36 @@ using klotski::codec::SHORT_CODE_LIMIT;
int main() {
const auto start = clock();
auto raw_code = RawCode::from_common_code(0x1A9BF0C00)->unwrap();
auto ret = klotski::cases::group_extend_from_seed(raw_code);
klotski::cases::spawn_ranges(2, 1, 4, 4);
std::cout << ret.size() << std::endl;
// std::vector<int> series {1, 2, 3, 4};
// do { // full permutation traversal
//
// for (auto s : series) {
// std::cout << s << " ";
// }
// std::cout << std::endl;
//
// } while (std::next_permutation(series.begin(), series.end()));
// std::array a{'a', 'b', 'c'};
// do {
// for (auto x : a) {
// std::cout << x;
// }
// std::cout << std::endl;
// }
// while (std::ranges::next_permutation(a).found);
// auto raw_code = RawCode::from_common_code(0x1A9BF0C00)->unwrap();
// auto ret = klotski::cases::group_extend_from_seed(raw_code);
//
// std::cout << ret.size() << std::endl;
// 1 A9BF0C00 -> 10 10 10 01 10 11 11 11 00 00 11 000000
// auto ret = klotski::cases::spawn_ranges(2, 1, 4, 4);
// std::cout << ret.size() << std::endl;
// auto kk = GroupUnion::create(123).value();
// std::cout << kk.size() << std::endl;

40
src/core_test/group_tmp/group_extend.cc

@ -20,3 +20,43 @@ TEST(Group, group_extend) {
auto hash_ret = hash::xxh3(codes.data(), codes.size() * sizeof(uint64_t));
EXPECT_EQ(hash_ret, 0x91BD28A749312A6D);
}
/// List every `(n_00, n_01, n_10, n_11)` tuple exercised by the ranges test.
///
/// For `n` in [0, 7], `n_2x1` in [0, n] and `n_1x1` in [0, 14 - 2 * n], in that
/// nesting order, the tuple is `(16 - 2 * n - n_1x1, n - n_2x1, n_2x1, n_1x1)`.
/// All 204 tuples are returned.
static std::vector<std::tuple<int, int, int, int>> target_nums() {
    std::vector<std::tuple<int, int, int, int>> tuples;

    int n = 0;
    while (n <= 7) {
        const int free_cells = 16 - n * 2; // split between the first and last fields
        for (int n_2x1 = 0; n_2x1 <= n; ++n_2x1) {
            for (int n_1x1 = 0; n_1x1 + 2 <= free_cells; ++n_1x1) {
                tuples.emplace_back(free_cells - n_1x1, n - n_2x1, n_2x1, n_1x1);
            }
        }
        ++n;
    }

    // tuples.resize(203);
    return tuples;
}
TEST(Group, ranges) {
    // Disabled single-case check, kept for reference:
    // auto ret = klotski::cases::spawn_ranges(2, 1, 4, 4);
    // EXPECT_EQ(ret.size(), 34650);
    // auto hash_ret = hash::xxh3(ret.data(), ret.size() * 4);
    // EXPECT_EQ(hash_ret, 0xF6F87606E4205EAF);

    // Concatenate the ranges of every target tuple, in generation order.
    std::vector<uint32_t> merged;
    for (const auto &[n_00, n_01, n_10, n_11] : target_nums()) {
        const auto part = klotski::cases::spawn_ranges(n_00, n_01, n_10, n_11);
        merged.insert(merged.end(), part.begin(), part.end());
    }

    // Verify both the total count and the content fingerprint.
    EXPECT_EQ(merged.size(), 7311921);
    const auto digest = hash::xxh3(merged.data(), merged.size() * sizeof(uint32_t));
    EXPECT_EQ(digest, 0xA1E247B01D5A9545);
}

Loading…
Cancel
Save