Browse Source

feat: more interfaces of `RangesUnion`

legacy
Dnomd343 1 month ago
parent
commit
cdd7f95f59
  1. 4
      src/core/all_cases/internal/all_cases.cc
  2. 28
      src/core/benchmark/ranges.cc
  3. 4
      src/core/group/internal/group.cc
  4. 12
      src/core/group/internal/group_cases.cc
  5. 10
      src/core/group/internal/group_union.cc
  6. 37
      src/core/ranges/internal/ranges.cc
  7. 12
      src/core/ranges/ranges.h
  8. 4
      src/core/short_code/internal/convert.cc
  9. 6
      src/core_test/cases/all_cases.cc
  10. 20
      src/core_test/cases/ranges.cc
  11. 4
      src/core_test/cases/ranges_union.cc
  12. 6
      src/core_test/mover/mover.cc

4
src/core/all_cases/internal/all_cases.cc

@ -65,7 +65,7 @@ void AllCases::build() {
x = range_reverse(x); x = range_reverse(x);
} }
for (const auto head : get_heads()) { for (const auto head : get_heads()) {
build_cases(ranges, reversed, get_cases()[head], head); build_cases(ranges, reversed, get_cases().ranges(head), head);
} }
available_ = true; available_ = true;
KLSK_MEM_BARRIER; KLSK_MEM_BARRIER;
@ -92,7 +92,7 @@ void AllCases::build_async(Executor &&executor, Notifier &&callback) {
Worker worker {executor}; Worker worker {executor};
for (const auto head : get_heads()) { for (const auto head : get_heads()) {
worker.post([head, reversed] { worker.post([head, reversed] {
build_cases(BasicRanges::instance().fetch(), *reversed, get_cases()[head], head); build_cases(BasicRanges::instance().fetch(), *reversed, get_cases().ranges(head), head);
}); });
} }

28
src/core/benchmark/ranges.cc

@ -31,8 +31,34 @@ static void RangesUnionExport(benchmark::State &state) {
} }
} }
BENCHMARK(SpawnRanges)->Unit(benchmark::kMillisecond); static void RangesSize(benchmark::State &state) {
auto &all_cases = AllCases::instance().fetch();
// std::cout << all_cases.size() << std::endl;
for (auto _ : state) {
volatile auto k1 = all_cases.size();
volatile auto k2 = all_cases.size();
volatile auto k3 = all_cases.size();
volatile auto k4 = all_cases.size();
volatile auto k5 = all_cases.size();
volatile auto k6 = all_cases.size();
volatile auto k7 = all_cases.size();
volatile auto k8 = all_cases.size();
volatile auto p1 = all_cases.size();
volatile auto p2 = all_cases.size();
volatile auto p3 = all_cases.size();
volatile auto p4 = all_cases.size();
volatile auto p5 = all_cases.size();
volatile auto p6 = all_cases.size();
volatile auto p7 = all_cases.size();
volatile auto p8 = all_cases.size();
}
}
// BENCHMARK(SpawnRanges)->Unit(benchmark::kMillisecond);
// BENCHMARK(RangesUnionExport)->Unit(benchmark::kMillisecond); // BENCHMARK(RangesUnionExport)->Unit(benchmark::kMillisecond);
BENCHMARK(RangesSize);
BENCHMARK_MAIN(); BENCHMARK_MAIN();

4
src/core/group/internal/group.cc

@ -68,11 +68,11 @@ RangesUnion Group::cases() const {
RangesUnion data; RangesUnion data;
for (auto raw_code : codes) { for (auto raw_code : codes) {
auto common_code = raw_code.to_common_code().unwrap(); auto common_code = raw_code.to_common_code().unwrap();
data[common_code >> 32].emplace_back(static_cast<uint32_t>(common_code)); data.ranges(common_code >> 32).emplace_back(static_cast<uint32_t>(common_code));
} }
for (int head = 0; head < 16; ++head) { for (int head = 0; head < 16; ++head) {
std::stable_sort(data[head].begin(), data[head].end()); std::stable_sort(data.ranges(head).begin(), data.ranges(head).end());
} }
return data; return data;
} }

12
src/core/group/internal/group_cases.cc

@ -119,15 +119,15 @@ CommonCode GroupCases::fast_obtain_code(CaseInfo info) {
auto case_id = info.case_id(); auto case_id = info.case_id();
for (;;) { for (;;) {
if (case_id >= cases[head].size()) { if (case_id >= cases.ranges(head).size()) {
case_id -= cases[head].size(); case_id -= cases.ranges(head).size();
++head; ++head;
} else { } else {
break; break;
} }
} }
auto range = cases[head][case_id]; auto range = cases.ranges(head)[case_id];
return CommonCode::unsafe_create(head << 32 | range); return CommonCode::unsafe_create(head << 32 | range);
} }
@ -173,15 +173,15 @@ CommonCode GroupCases::tiny_obtain_code(CaseInfo info) {
auto case_id = info.case_id(); auto case_id = info.case_id();
for (;;) { for (;;) {
if (case_id >= cases[head].size()) { if (case_id >= cases.ranges(head).size()) {
case_id -= cases[head].size(); case_id -= cases.ranges(head).size();
++head; ++head;
} else { } else {
break; break;
} }
} }
auto range = cases[head][case_id]; auto range = cases.ranges(head)[case_id];
return CommonCode::unsafe_create(head << 32 | range); return CommonCode::unsafe_create(head << 32 | range);
} }

10
src/core/group/internal/group_union.cc

@ -7,7 +7,7 @@ using klotski::cases::RangesUnion;
using klotski::cases::BASIC_RANGES_NUM; using klotski::cases::BASIC_RANGES_NUM;
#define RANGE_DERIVE(HEAD) ranges.derive(HEAD, cases[HEAD]) #define RANGE_DERIVE(HEAD) ranges.derive(HEAD, cases.ranges(HEAD))
RangesUnion GroupUnion::cases() const { RangesUnion GroupUnion::cases() const {
auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id_]; auto [n, n_2x1, n_1x1] = BLOCK_NUM[type_id_];
@ -19,10 +19,10 @@ RangesUnion GroupUnion::cases() const {
ranges.reverse(); ranges.reverse();
RangesUnion cases; RangesUnion cases;
cases[0x0].reserve(s_a); cases[0x1].reserve(s_b); cases[0x2].reserve(s_a); cases.ranges(0x0).reserve(s_a); cases.ranges(0x1).reserve(s_b); cases.ranges(0x2).reserve(s_a);
cases[0x4].reserve(s_c); cases[0x5].reserve(s_d); cases[0x6].reserve(s_c); cases.ranges(0x4).reserve(s_c); cases.ranges(0x5).reserve(s_d); cases.ranges(0x6).reserve(s_c);
cases[0x8].reserve(s_c); cases[0x9].reserve(s_d); cases[0xA].reserve(s_c); cases.ranges(0x8).reserve(s_c); cases.ranges(0x9).reserve(s_d); cases.ranges(0xA).reserve(s_c);
cases[0xC].reserve(s_a); cases[0xD].reserve(s_b); cases[0xE].reserve(s_a); cases.ranges(0xC).reserve(s_a); cases.ranges(0xD).reserve(s_b); cases.ranges(0xE).reserve(s_a);
RANGE_DERIVE(0x0); RANGE_DERIVE(0x1); RANGE_DERIVE(0x2); RANGE_DERIVE(0x0); RANGE_DERIVE(0x1); RANGE_DERIVE(0x2);
RANGE_DERIVE(0x4); RANGE_DERIVE(0x5); RANGE_DERIVE(0x6); RANGE_DERIVE(0x4); RANGE_DERIVE(0x5); RANGE_DERIVE(0x6);

37
src/core/ranges/internal/ranges.cc

@ -23,23 +23,44 @@ Ranges& Ranges::operator+=(const Ranges &ranges) {
RangesUnion& RangesUnion::operator+=(const RangesUnion &ranges_union) { RangesUnion& RangesUnion::operator+=(const RangesUnion &ranges_union) {
for (const auto head : heads) { for (const auto head : heads) {
(*this)[head] += ranges_union[head]; // (*this)[head] += ranges_union[head];
std::array<Ranges, 16>::operator[](head) += ranges_union.std::array<Ranges, 16>::operator[](head);
} }
return *this; return *this;
} }
std::vector<CommonCode> RangesUnion::codes() const { std::vector<CommonCode> RangesUnion::codes() const {
size_type size = 0;
for (const auto head : heads) {
size += (*this)[head].size();
}
std::vector<CommonCode> codes; std::vector<CommonCode> codes;
codes.reserve(size); codes.reserve(size());
for (const auto head : heads) { for (const auto head : heads) {
for (const auto range : (*this)[head]) { // for (const auto range : (*this)[head]) {
for (const auto range : ranges(head)) {
codes.emplace_back(CommonCode::unsafe_create(head << 32 | range)); codes.emplace_back(CommonCode::unsafe_create(head << 32 | range));
} }
} }
// TODO: try using std::views
return codes; return codes;
} }
// TODO: move to `.inl` file
size_t RangesUnion::size() const {
size_type size = 0;
for (const auto head : heads) {
size += std::array<Ranges, 16>::operator[](head).size();
// size += (*this)[head].size();
}
return size;
}
uint32_t RangesUnion::operator[](size_type index) const {
size_t head = 0;
for (;;) {
if (index >= std::array<Ranges, 16>::operator[](head).size()) {
index -= std::array<Ranges, 16>::operator[](head).size();
++head;
} else {
break;
}
}
return std::array<Ranges, 16>::operator[](head)[index];
}

12
src/core/ranges/ranges.h

@ -41,6 +41,18 @@ public:
/// Export the RangesUnion as CommonCode list. /// Export the RangesUnion as CommonCode list.
[[nodiscard]] std::vector<codec::CommonCode> codes() const; [[nodiscard]] std::vector<codec::CommonCode> codes() const;
[[nodiscard]] const Ranges& ranges(const size_t head) const {
return std::array<Ranges, 16>::operator[](head);
}
Ranges& ranges(const size_t head) {
return std::array<Ranges, 16>::operator[](head);
}
[[nodiscard]] size_t size() const;
[[nodiscard]] uint32_t operator[](size_type) const;
}; };
} // namespace klotski::cases } // namespace klotski::cases

4
src/core/short_code/internal/convert.cc

@ -52,7 +52,7 @@ static uint32_t check_range(uint32_t head, uint32_t range) noexcept {
uint32_t ShortCode::fast_encode(uint64_t common_code) { uint32_t ShortCode::fast_encode(uint64_t common_code) {
auto head = common_code >> 32; auto head = common_code >> 32;
auto &ranges = (*cases_)[head]; // match available ranges const auto &ranges = (*cases_).ranges(head); // match available ranges
// TODO: try to narrow the scope by prefix // TODO: try to narrow the scope by prefix
auto target = std::lower_bound(ranges.begin(), ranges.end(), (uint32_t)common_code); auto target = std::lower_bound(ranges.begin(), ranges.end(), (uint32_t)common_code);
return ALL_CASES_OFFSET[head] + (target - ranges.begin()); return ALL_CASES_OFFSET[head] + (target - ranges.begin());
@ -62,7 +62,7 @@ uint64_t ShortCode::fast_decode(uint32_t short_code) {
auto offset = std::upper_bound(ALL_CASES_OFFSET.begin(), ALL_CASES_OFFSET.end(), short_code) - 1; auto offset = std::upper_bound(ALL_CASES_OFFSET.begin(), ALL_CASES_OFFSET.end(), short_code) - 1;
uint64_t head = offset - ALL_CASES_OFFSET.begin(); uint64_t head = offset - ALL_CASES_OFFSET.begin();
// return (head << 32) | AllCases::instance().fetch()[head][short_code - *offset]; // return (head << 32) | AllCases::instance().fetch()[head][short_code - *offset];
return (head << 32) | (*cases_)[head][short_code - *offset]; return (head << 32) | (*cases_).ranges(head)[short_code - *offset];
} }
uint32_t ShortCode::tiny_encode(uint64_t common_code) { uint32_t ShortCode::tiny_encode(uint64_t common_code) {

6
src/core_test/cases/all_cases.cc

@ -51,15 +51,15 @@ protected:
static void Verify() { static void Verify() {
const auto &all_cases = AllCases::instance().fetch(); const auto &all_cases = AllCases::instance().fetch();
for (int head = 0; head < 16; ++head) { for (int head = 0; head < 16; ++head) {
EXPECT_EQ(all_cases[head].size(), ALL_CASES_NUM[head]); // verify all cases size EXPECT_EQ(all_cases.ranges(head).size(), ALL_CASES_NUM[head]); // verify all cases size
EXPECT_EQ(helper::xxh3(all_cases[head]), ALL_CASES_XXH3[head]); // verify all cases checksum EXPECT_EQ(helper::xxh3(all_cases.ranges(head)), ALL_CASES_XXH3[head]); // verify all cases checksum
} }
} }
}; };
TEST_FF(AllCases, content) { TEST_FF(AllCases, content) {
for (auto head : Heads) { for (auto head : Heads) {
auto &cases = AllCases::instance().fetch()[head]; auto &cases = AllCases::instance().fetch().ranges(head);
EXPECT_SORTED_AND_UNIQUE(cases); EXPECT_SORTED_AND_UNIQUE(cases);
EXPECT_EQ(cases.size(), ALL_CASES_NUM[head]); // size verify EXPECT_EQ(cases.size(), ALL_CASES_NUM[head]); // size verify
EXPECT_EQ(helper::xxh3(cases), ALL_CASES_XXH3[head]); // checksum verify EXPECT_EQ(helper::xxh3(cases), ALL_CASES_XXH3[head]); // checksum verify

20
src/core_test/cases/ranges.cc

@ -16,7 +16,7 @@ TEST(Ranges, check) {
for (const auto head : Heads) { for (const auto head : Heads) {
for (auto range : BasicRanges::instance().fetch()) { for (auto range : BasicRanges::instance().fetch()) {
if (Ranges::check(head, range_reverse(range)) == 0) { if (Ranges::check(head, range_reverse(range)) == 0) {
all_cases[head].emplace_back(range); // found valid cases all_cases.ranges(head).emplace_back(range); // found valid cases
} }
} }
} }
@ -45,14 +45,14 @@ TEST(Ranges, derive) {
RangesUnion cases; RangesUnion cases;
for (const auto head : Heads) { for (const auto head : Heads) {
ranges.derive(head, cases[head]); ranges.derive(head, cases.ranges(head));
EXPECT_SORTED_AND_UNIQUE(cases[head]); // sorted and unique EXPECT_SORTED_AND_UNIQUE(cases.ranges(head)); // sorted and unique
EXPECT_COMMON_CODES(head, cases[head]); // verify common codes EXPECT_COMMON_CODES(head, cases.ranges(head)); // verify common codes
} }
ranges.reverse(); ranges.reverse();
for (const auto head : Heads) { for (const auto head : Heads) {
EXPECT_SUBSET(ranges, cases[head]); // derive ranges is subset EXPECT_SUBSET(ranges, cases.ranges(head)); // derive ranges is subset
} }
} }
} }
@ -83,7 +83,7 @@ TEST(Ranges, combine) {
RangesUnion all_cases; RangesUnion all_cases;
all_ranges.reserve(BASIC_RANGES_NUM_); // pre reserve all_ranges.reserve(BASIC_RANGES_NUM_); // pre reserve
for (const auto head : Heads) { for (const auto head : Heads) {
all_cases[head].reserve(ALL_CASES_NUM[head]); // pre reserve all_cases.ranges(head).reserve(ALL_CASES_NUM[head]); // pre reserve
} }
for (auto [n, n_2x1, n_1x1] : BLOCK_NUM) { for (auto [n, n_2x1, n_1x1] : BLOCK_NUM) {
@ -92,21 +92,21 @@ TEST(Ranges, combine) {
all_ranges += ranges; all_ranges += ranges;
ranges.reverse(); // reverse ranges for derive ranges.reverse(); // reverse ranges for derive
for (const auto head : Heads) { for (const auto head : Heads) {
ranges.derive(head, all_cases[head]); // derive from sub ranges ranges.derive(head, all_cases.ranges(head)); // derive from sub ranges
} }
} }
std::ranges::stable_sort(all_ranges.begin(), all_ranges.end()); std::ranges::stable_sort(all_ranges.begin(), all_ranges.end());
for (const auto head : Heads) { for (const auto head : Heads) {
std::ranges::stable_sort(all_cases[head].begin(), all_cases[head].end()); std::ranges::stable_sort(all_cases.ranges(head).begin(), all_cases.ranges(head).end());
} }
EXPECT_EQ(all_ranges, BasicRanges::instance().fetch()); // verify all ranges EXPECT_EQ(all_ranges, BasicRanges::instance().fetch()); // verify all ranges
EXPECT_EQ(all_cases, AllCases::instance().fetch()); // verify all cases EXPECT_EQ(all_cases, AllCases::instance().fetch()); // verify all cases
all_ranges.reverse(); // reverse ranges for derive all_ranges.reverse(); // reverse ranges for derive
for (const auto head : Heads) { for (const auto head : Heads) {
all_cases[head].clear(); all_cases.ranges(head).clear();
all_ranges.derive(head, all_cases[head]); // derive from all ranges all_ranges.derive(head, all_cases.ranges(head)); // derive from all ranges
} }
EXPECT_EQ(all_cases, AllCases::instance().fetch()); // verify content EXPECT_EQ(all_cases, AllCases::instance().fetch()); // verify content
} }

4
src/core_test/cases/ranges_union.cc

@ -27,14 +27,14 @@ TEST(RangesUnion, append) {
RangesUnion ru; RangesUnion ru;
for (const auto head : Heads) { for (const auto head : Heads) {
r.derive(head, ru[head]); r.derive(head, ru.ranges(head));
} }
auto &tmp = cases += ru; auto &tmp = cases += ru;
EXPECT_EQ(tmp, cases); // reference of cases EXPECT_EQ(tmp, cases); // reference of cases
} }
for (const auto head : Heads) { for (const auto head : Heads) {
std::stable_sort(cases[head].begin(), cases[head].end()); std::stable_sort(cases.ranges(head).begin(), cases.ranges(head).end());
} }
EXPECT_EQ(cases, AllCases::instance().fetch()); EXPECT_EQ(cases, AllCases::instance().fetch());
} }

6
src/core_test/mover/mover.cc

@ -38,7 +38,7 @@ TEST(Core, core) {
// codes.reserve(klotski::cases::ALL_CASES_NUM_); // codes.reserve(klotski::cases::ALL_CASES_NUM_);
for (uint64_t head = 0; head < 16; ++head) { for (uint64_t head = 0; head < 16; ++head) {
for (const auto range : AllCases::instance().fetch()[head]) { for (const auto range : AllCases::instance().fetch().ranges(head)) {
auto common_code = CommonCode::unsafe_create(head << 32 | range); auto common_code = CommonCode::unsafe_create(head << 32 | range);
auto raw_code = common_code.to_raw_code().unwrap(); auto raw_code = common_code.to_raw_code().unwrap();
@ -73,7 +73,7 @@ TEST(Core, mask) {
raw_codes.reserve(klotski::cases::ALL_CASES_NUM_); raw_codes.reserve(klotski::cases::ALL_CASES_NUM_);
for (uint64_t head = 0; head < 16; ++head) { for (uint64_t head = 0; head < 16; ++head) {
for (const auto range : AllCases::instance().fetch()[head]) { for (const auto range : AllCases::instance().fetch().ranges(head)) {
auto common_code = CommonCode::unsafe_create(head << 32 | range); auto common_code = CommonCode::unsafe_create(head << 32 | range);
auto raw_code = common_code.to_raw_code().unwrap(); auto raw_code = common_code.to_raw_code().unwrap();
raw_codes.emplace_back(raw_code); raw_codes.emplace_back(raw_code);
@ -173,7 +173,7 @@ TEST(Core, next_cases) {
std::vector<RawCode> raw_codes; std::vector<RawCode> raw_codes;
for (uint64_t head = 0; head < 16; ++head) { for (uint64_t head = 0; head < 16; ++head) {
for (const auto range : AllCases::instance().fetch()[head]) { for (const auto range : AllCases::instance().fetch().ranges(head)) {
auto common_code = CommonCode::unsafe_create(head << 32 | range); auto common_code = CommonCode::unsafe_create(head << 32 | range);
auto raw_code = common_code.to_raw_code(); auto raw_code = common_code.to_raw_code();

Loading…
Cancel
Save