diff --git a/src/core/all_cases/all_cases.cc b/src/core/all_cases/all_cases.cc index 9f0f48e..4deac7b 100644 --- a/src/core/all_cases/all_cases.cc +++ b/src/core/all_cases/all_cases.cc @@ -51,6 +51,7 @@ static int check_range(int head, uint32_t range) noexcept { /// Build all valid ranges of the specified head. void AllCases::BuildCases(int head, Ranges &release) noexcept { + release.clear(); release.reserve(ALL_CASES_NUM[head]); auto &basic_ranges = BasicRanges::Instance().Fetch(); for (uint32_t index = 0; index < basic_ranges.size(); ++index) { @@ -76,30 +77,6 @@ void AllCases::Build() noexcept { }); } -/// Execute the build process in parallel without blocking. -void AllCases::BuildParallelAsync(Executor &&executor, Notifier &&callback) noexcept { - if (available_) { - return; // reduce consumption of mutex - } - building_.lock(); - if (available_) { - building_.unlock(); - return; // data is already available - } - auto counter = std::make_shared>(0); - auto all_done = std::make_shared(std::move(callback)); - for (auto head : case_heads()) { - executor([this, head, counter, all_done]() { - BuildCases(head, GetCases()[head]); - if (counter->fetch_add(1) == case_heads().size() - 1) { - available_ = true; - building_.unlock(); // release building mutex - all_done->operator()(); // trigger callback - } - }); - } -} - /// Execute the build process with parallel support and ensure thread safety. void AllCases::BuildParallel(Executor &&executor) noexcept { if (available_) { @@ -126,6 +103,32 @@ void AllCases::BuildParallel(Executor &&executor) noexcept { available_ = true; } +/// Execute the build process in parallel without blocking. +void AllCases::BuildParallelAsync(Executor &&executor, Notifier &&callback) noexcept { + if (available_) { + callback(); + return; // reduce consumption of mutex + } + building_.lock(); + if (available_) { + building_.unlock(); + callback(); + return; // data is already available + } + auto counter = std::make_shared>(0); + auto all_done = std::make_shared(std::move(callback)); + for (auto head : case_heads()) { + executor([this, head, counter, all_done]() { + BuildCases(head, GetCases()[head]); + if (counter->fetch_add(1) == case_heads().size() - 1) { // all tasks done + available_ = true; + building_.unlock(); // release building mutex + all_done->operator()(); // trigger callback + } + }); + } +} + RangesUnion& AllCases::GetCases() noexcept { static RangesUnion cases; return cases; diff --git a/src/core/all_cases/basic_ranges.cc b/src/core/all_cases/basic_ranges.cc index 0247e9a..1fc4b63 100644 --- a/src/core/all_cases/basic_ranges.cc +++ b/src/core/all_cases/basic_ranges.cc @@ -111,7 +111,7 @@ const Ranges& BasicRanges::Fetch() noexcept { } bool BasicRanges::IsAvailable() const noexcept { - return available_; + return available_; // no mutex required in one-way state } } // namespace cases diff --git a/src/core_test/CMakeLists.txt b/src/core_test/CMakeLists.txt index f86645a..3e70cb3 100644 --- a/src/core_test/CMakeLists.txt +++ b/src/core_test/CMakeLists.txt @@ -6,6 +6,8 @@ set(KLOTSKI_TEST_DEPS klotski-core gtest gtest_main md5sum) ############################################################################################### +include_directories(utils) + include_directories(${KLOTSKI_ROOT}/src/core/utils) include_directories(${KLOTSKI_ROOT}/src/core/all_cases) diff --git a/src/core_test/cases/all_cases.cc b/src/core_test/cases/all_cases.cc index d19ad04..990e111 100644 --- a/src/core_test/cases/all_cases.cc +++ b/src/core_test/cases/all_cases.cc @@ -1,8 +1,8 @@ #include -#include #include #include "md5sum.h" +#include "exposer.h" #include "all_cases.h" #include "gtest/gtest.h" #include "BS_thread_pool.hpp" @@ -23,63 +23,47 @@ static const auto TEST_THREAD_NUM = 256; static const std::string ALL_CASES_MD5 = "3888e9fab8d3cbb50908b12b147cfb23"; static const std::string BASIC_RANGES_MD5 = "6f385dc171e201089ff96bb010b47212"; -TEST(Cases, basic_ranges_build) { - std::vector threads; - threads.reserve(TEST_THREAD_NUM); - for (int i = 0; i < TEST_THREAD_NUM; ++i) { - threads.emplace_back([]() { - BasicRanges::Instance().Build(); - }); - } - for (auto &t : threads) { - t.join(); - } - EXPECT_TRUE(BasicRanges::Instance().IsAvailable()); +/// Forcibly modify private variables to reset state. +PRIVATE_ACCESS(AllCases, available_, bool) +PRIVATE_ACCESS(BasicRanges, available_, bool) + +/// Reset basic ranges build state, note it is thread-unsafe. +void basic_ranges_reset() { + access_BasicRanges_available_(BasicRanges::Instance()) = false; } -TEST(Cases, basic_ranges_size) { - auto &basic_ranges = BasicRanges::Instance().Fetch(); - EXPECT_EQ(basic_ranges.size(), BASIC_RANGES_NUM); +/// Reset all cases build state, note it is thread-unsafe. +void all_cases_reset() { + access_AllCases_available_(AllCases::Instance()) = false; } -TEST(Cases, basic_ranges_data) { +/// Verify that whether basic ranges data is correct. +void basic_ranges_verify() { + auto &basic_ranges = BasicRanges::Instance().Fetch(); + EXPECT_EQ(basic_ranges.size(), BASIC_RANGES_NUM); // verify basic ranges size + std::string basic_ranges_str; basic_ranges_str.reserve(BASIC_RANGES_NUM * 9); // 8-bit + '\n'` - for (auto range : BasicRanges::Instance().Fetch()) { + for (auto range : basic_ranges) { char *tmp = nullptr; asprintf(&tmp, "%08X\n", range); basic_ranges_str += tmp; } - EXPECT_EQ(md5sum(basic_ranges_str), BASIC_RANGES_MD5); -} - -TEST(Cases, all_cases_build) { - std::vector threads; - threads.reserve(TEST_THREAD_NUM); - for (int i = 0; i < TEST_THREAD_NUM; ++i) { - threads.emplace_back([]() { - AllCases::Instance().Build(); - }); - } - for (auto &t : threads) { - t.join(); - } - EXPECT_TRUE(AllCases::Instance().IsAvailable()); + EXPECT_EQ(md5sum(basic_ranges_str), BASIC_RANGES_MD5); // verify basic ranges checksum } -TEST(Cases, all_cases_size) { +/// Verify that whether all cases data is correct. +void all_cases_verify() { auto &all_cases = AllCases::Instance().Fetch(); for (int head = 0; head < 16; ++head) { - EXPECT_EQ(all_cases[head].size(), ALL_CASES_NUM[head]); + EXPECT_EQ(all_cases[head].size(), ALL_CASES_NUM[head]); // verify all cases size } auto all_cases_num = 0; - for (auto num : ALL_CASES_NUM) { - all_cases_num += num; - } - EXPECT_EQ(all_cases_num, ALL_CASES_NUM_); -} + std::for_each(all_cases.begin(), all_cases.end(), [&all_cases_num](auto &ranges) { + all_cases_num += ranges.size(); + }); + EXPECT_EQ(all_cases_num, ALL_CASES_NUM_); // verify all cases global size -TEST(Cases, all_cases_data) { std::string all_cases_str; all_cases_str.reserve(ALL_CASES_NUM_ * 10); // 9-bit + '\n' for (uint64_t head = 0; head < 16; ++head) { @@ -89,36 +73,132 @@ TEST(Cases, all_cases_data) { all_cases_str += tmp; } } - EXPECT_EQ(md5sum(all_cases_str), ALL_CASES_MD5); + EXPECT_EQ(md5sum(all_cases_str), ALL_CASES_MD5); // verify all cases checksum } -// TODO: test all_cases_parallel_build +TEST(Cases, basic_ranges) { + basic_ranges_reset(); + EXPECT_FALSE(BasicRanges::Instance().IsAvailable()); + BasicRanges::Instance().Build(); + EXPECT_TRUE(BasicRanges::Instance().IsAvailable()); + BasicRanges::Instance().Build(); + EXPECT_TRUE(BasicRanges::Instance().IsAvailable()); + basic_ranges_verify(); +} -TEST(Cases, thread_pool_demo) { +TEST(Cases, basic_ranges_mutex) { + basic_ranges_reset(); + BS::thread_pool pool(TEST_THREAD_NUM); - BasicRanges::Instance().Build(); + for (int i = 0; i < TEST_THREAD_NUM; ++i) { + auto _ = pool.submit(&BasicRanges::Build, &BasicRanges::Instance()); + } + EXPECT_FALSE(BasicRanges::Instance().IsAvailable()); + pool.wait_for_tasks(); + EXPECT_TRUE(BasicRanges::Instance().IsAvailable()); + basic_ranges_verify(); +} - BS::thread_pool pool; +TEST(Cases, all_cases) { + all_cases_reset(); + EXPECT_FALSE(AllCases::Instance().IsAvailable()); + AllCases::Instance().Build(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + AllCases::Instance().Build(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + all_cases_verify(); +} - std::cout << pool.get_thread_count() << std::endl; +TEST(Cases, all_cases_mutex) { + all_cases_reset(); + BS::thread_pool pool(TEST_THREAD_NUM); - auto start = clock(); - auto start_ = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < TEST_THREAD_NUM; ++i) { + auto _ = pool.submit(&AllCases::Build, &AllCases::Instance()); + } + EXPECT_FALSE(AllCases::Instance().IsAvailable()); + pool.wait_for_tasks(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + all_cases_verify(); +} - AllCases::Instance().BuildParallel([&pool](std::function &&func) { -// std::cout << "receive new task" << std::endl; - pool.push_task(func); +TEST(Cases, all_cases_parallel) { + all_cases_reset(); + BS::thread_pool executor; + EXPECT_FALSE(AllCases::Instance().IsAvailable()); + AllCases::Instance().BuildParallel([&executor](auto &&func) { + executor.push_task(func); }); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + AllCases::Instance().BuildParallel([&executor](auto &&func) { + executor.push_task(func); + }); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + all_cases_verify(); +} -// std::cout << "parallel build complete" << std::endl; +TEST(Cases, all_cases_parallel_mutex) { + all_cases_reset(); + BS::thread_pool executor; + BS::thread_pool pool(TEST_THREAD_NUM); + for (int i = 0; i < TEST_THREAD_NUM; ++i) { + auto _ = pool.submit(&AllCases::BuildParallel, &AllCases::Instance(), [&executor](auto &&func) { + executor.push_task(func); + }); + } + EXPECT_FALSE(AllCases::Instance().IsAvailable()); pool.wait_for_tasks(); + executor.wait_for_tasks(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + all_cases_verify(); +} - std::cerr << ((clock() - start) * 1000 / CLOCKS_PER_SEC) << "ms" << std::endl; - auto end = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast(end - start_); - std::cerr << elapsed.count() / 1000 / 1000 << "ms" << std::endl; +TEST(Cases, all_cases_async) { + all_cases_reset(); + BS::thread_pool executor; -// std::cout << "pool tasks complete" << std::endl; + std::promise promise_1; + auto future_1 = promise_1.get_future(); + AllCases::Instance().BuildParallelAsync([&executor](auto &&func) { + executor.push_task(func); + }, [&promise_1]() { + promise_1.set_value(); + }); + EXPECT_FALSE(AllCases::Instance().IsAvailable()); + future_1.wait(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + std::promise promise_2; + auto future_2 = promise_2.get_future(); + AllCases::Instance().BuildParallelAsync([&executor](auto &&func) { + executor.push_task(func); + }, [&promise_2]() { + promise_2.set_value(); + }); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + future_2.wait(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + all_cases_verify(); +} + +TEST(Cases, all_cases_async_mutex) { + all_cases_reset(); + BS::thread_pool executor; + std::atomic callback_num(0); + BS::thread_pool pool(TEST_THREAD_NUM); + + for (int i = 0; i < TEST_THREAD_NUM; ++i) { + auto _ = pool.submit(&AllCases::BuildParallelAsync, &AllCases::Instance(), [&executor](auto &&func) { + executor.push_task(func); + }, [&callback_num]() { + callback_num.fetch_add(1); + }); + } + EXPECT_FALSE(AllCases::Instance().IsAvailable()); + pool.wait_for_tasks(); + executor.wait_for_tasks(); + EXPECT_TRUE(AllCases::Instance().IsAvailable()); + EXPECT_EQ(callback_num.load(), TEST_THREAD_NUM); + all_cases_verify(); } diff --git a/src/core_test/utils/exposer.h b/src/core_test/utils/exposer.h new file mode 100644 index 0000000..e7ad736 --- /dev/null +++ b/src/core_test/utils/exposer.h @@ -0,0 +1,33 @@ +#pragma once + +/// The exposer can forcibly access private members of a class without changing +/// any code. It uses macros to construct a function that returns a reference +/// to the target member variable. + +namespace exposer { + +template +struct Exposer { + static T ptr; +}; + +template +T Exposer::ptr; + +template +struct ExposerImpl { + static struct Factory { + Factory() { Exposer::ptr = Ptr; } + } factory; +}; + +template +typename ExposerImpl::Factory ExposerImpl::factory; + +} // namespace exposer + +#define PRIVATE_ACCESS(Class, Member, Type) \ + template struct ::exposer::ExposerImpl; \ + Type& access_##Class##_##Member(Class &T) { \ + return T.*exposer::Exposer::ptr; \ + }