From c413fadea9b66f2534fb19c2584ffc0c4a75db0c Mon Sep 17 00:00:00 2001 From: Francisco Paisana Date: Fri, 20 Sep 2019 15:43:26 +0100 Subject: [PATCH] created a queue-based thread pool. Tasks are inserted into a queue and then popped by the thread pool workers to be processed. --- lib/include/srslte/common/thread_pool.h | 55 ++++++++++- lib/include/srslte/common/threads.h | 8 +- lib/src/common/thread_pool.cc | 123 ++++++++++++++++++++++++ lib/test/common/queue_test.cc | 97 ++++++++++++++++++- 4 files changed, 275 insertions(+), 8 deletions(-) diff --git a/lib/include/srslte/common/thread_pool.h b/lib/include/srslte/common/thread_pool.h index b439918f0..c237ce897 100644 --- a/lib/include/srslte/common/thread_pool.h +++ b/lib/include/srslte/common/thread_pool.h @@ -29,10 +29,15 @@ #ifndef SRSLTE_THREAD_POOL_H #define SRSLTE_THREAD_POOL_H +#include +#include +#include +#include +#include +#include #include #include #include -#include #include "srslte/common/threads.h" @@ -94,8 +99,50 @@ private: std::vector status; std::vector cvar; std::vector mutex; - std::stack available_workers; + std::stack available_workers; }; -} - + +class task_thread_pool +{ + using task_t = std::function; + +public: + task_thread_pool(uint32_t nof_workers); + ~task_thread_pool(); + void start(int32_t prio = -1, uint32_t mask = 255); + void stop(); + + void push_task(const task_t& task); + void push_task(task_t&& task); + uint32_t nof_pending_tasks(); + +private: + class worker_t : public thread + { + public: + explicit worker_t(task_thread_pool* parent_, uint32_t id); + void setup(int32_t prio, uint32_t mask); + bool is_running() const { return running; } + uint32_t id() const { return id_; } + + void run_thread() override; + + private: + bool wait_task(task_t* task); + + task_thread_pool* parent = nullptr; + uint32_t id_ = 0; + bool running = false; + }; + + std::queue pending_tasks; + std::vector workers; + std::mutex queue_mutex; + std::condition_variable cv_empty, cv_exit; + bool running; + uint32_t nof_workers_running = 0; +}; + +} // namespace srslte + #endif // SRSLTE_THREAD_POOL_H diff --git a/lib/include/srslte/common/threads.h b/lib/include/srslte/common/threads.h index 7313ed1d1..e47e9c070 100644 --- a/lib/include/srslte/common/threads.h +++ b/lib/include/srslte/common/threads.h @@ -50,9 +50,11 @@ class thread { public: thread(const std::string& name_) : _thread(0), name(name_) {} - bool start(int prio = -1) { - return threads_new_rt_prio(&_thread, thread_function_entry, this, prio); - } + thread(const thread&) = delete; + thread(thread&&) noexcept = default; + thread& operator=(const thread&) = delete; + thread& operator=(thread&&) noexcept = default; + bool start(int prio = -1) { return threads_new_rt_prio(&_thread, thread_function_entry, this, prio); } bool start_cpu(int prio, int cpu) { return threads_new_rt_cpu(&_thread, thread_function_entry, this, cpu, prio); } diff --git a/lib/src/common/thread_pool.cc b/lib/src/common/thread_pool.cc index 1eddae834..122fec2ec 100644 --- a/lib/src/common/thread_pool.cc +++ b/lib/src/common/thread_pool.cc @@ -280,6 +280,129 @@ uint32_t thread_pool::get_nof_workers() return nof_workers; } +/************************************************************************** + * task_thread_pool - uses a queue to enqueue callables, that start + * once a worker is available + *************************************************************************/ + +task_thread_pool::task_thread_pool(uint32_t nof_workers) : running(false) +{ + workers.reserve(nof_workers); + for (uint32_t i = 0; i < nof_workers; ++i) { + workers.emplace_back(this, i); + } +} + +task_thread_pool::~task_thread_pool() +{ + stop(); +} + +void task_thread_pool::start(int32_t prio, uint32_t mask) +{ + std::lock_guard lock(queue_mutex); + running = true; + for (worker_t& w : workers) { + w.setup(prio, mask); + } +} + +void task_thread_pool::stop() +{ + std::unique_lock lock(queue_mutex); + running = false; + do { + nof_workers_running = 0; + // next worker that is still running + for (worker_t& w : workers) { + if (w.is_running()) { + nof_workers_running++; + } + } + if (nof_workers_running > 0) { + lock.unlock(); + cv_empty.notify_all(); + lock.lock(); + cv_exit.wait(lock); + } + } while (nof_workers_running > 0); +} + +void task_thread_pool::push_task(const task_t& task) +{ + { + std::lock_guard lock(queue_mutex); + pending_tasks.push(task); + } + cv_empty.notify_one(); +} + +void task_thread_pool::push_task(task_t&& task) +{ + { + std::lock_guard lock(queue_mutex); + pending_tasks.push(std::move(task)); + } + cv_empty.notify_one(); } +uint32_t task_thread_pool::nof_pending_tasks() +{ + std::lock_guard lock(queue_mutex); + return pending_tasks.size(); +} + +task_thread_pool::worker_t::worker_t(srslte::task_thread_pool* parent_, uint32_t my_id) : + parent(parent_), + thread("TASKWORKER"), + id_(my_id) +{ + set_name(std::string("TASKWORKER") + std::to_string(my_id)); +} + +void task_thread_pool::worker_t::setup(int32_t prio, uint32_t mask) +{ + if (mask == 255) { + start(prio); + } else { + start_cpu_mask(prio, mask); + } +} + +bool task_thread_pool::worker_t::wait_task(task_t* task) +{ + std::unique_lock lock(parent->queue_mutex); + while (parent->running and parent->pending_tasks.empty()) { + parent->cv_empty.wait(lock); + } + if (not parent->running) { + return false; + } + if (task) { + *task = std::move(parent->pending_tasks.front()); + } + parent->pending_tasks.pop(); + return true; +} + +void task_thread_pool::worker_t::run_thread() +{ + running = true; + + // main loop + task_t task; + while (wait_task(&task)) { + task(id()); + } + + // on exit, notify pool class + std::unique_lock lock(parent->queue_mutex); + running = false; + parent->nof_workers_running--; + if (parent->nof_workers_running == 0) { + lock.unlock(); + parent->cv_exit.notify_one(); + } +} +} // namespace srslte diff --git a/lib/test/common/queue_test.cc b/lib/test/common/queue_test.cc index 72b66630a..bca5aec1c 100644 --- a/lib/test/common/queue_test.cc +++ b/lib/test/common/queue_test.cc @@ -19,8 +19,9 @@ * */ +#include "srslte/common/multiqueue.h" +#include "srslte/common/thread_pool.h" #include -#include #include #include @@ -214,10 +215,104 @@ int test_multiqueue_threading3() return 0; } +int test_task_thread_pool() +{ + std::cout << "\n====== TEST task thread pool test 1: start ======\n"; + // Description: check whether the tasks are successfully distributed between workers + + uint32_t nof_workers = 4, nof_runs = 10000; + std::vector count_worker(nof_workers, 0); + std::vector count_mutex(nof_workers); + + task_thread_pool thread_pool(nof_workers); + thread_pool.start(); + + auto task = [&count_worker, &count_mutex](uint32_t worker_id) { + std::lock_guard lock(count_mutex[worker_id]); + // std::cout << "hello world from worker " << worker_id << std::endl; + count_worker[worker_id]++; + }; + + for (uint32_t i = 0; i < nof_runs; ++i) { + thread_pool.push_task(task); + } + + // wait for all tasks to be successfully processed + while (thread_pool.nof_pending_tasks() > 0) { + usleep(100); + } + + thread_pool.stop(); + + uint32_t total_count = 0; + for (uint32_t i = 0; i < nof_workers; ++i) { + if (count_worker[i] < 10) { + printf("the number of tasks %d assigned to worker %d is too low\n", count_worker[i], i); + return -1; + } + total_count += count_worker[i]; + printf("worker %d: %d runs\n", i, count_worker[i]); + } + if (total_count != nof_runs) { + printf("Number of task runs=%d does not match total=%d\n", total_count, nof_runs); + return -1; + } + + std::cout << "outcome: Success\n"; + std::cout << "===================================================\n"; + return 0; +} + +int test_task_thread_pool2() +{ + std::cout << "\n====== TEST task thread pool test 2: start ======\n"; + // Description: push a very long task to all workers, and call thread_pool.stop() to check if it waits for the tasks + // to be completed, and does not get stuck. + + uint32_t nof_workers = 4; + uint8_t workers_started = 0, workers_finished = 0; + std::mutex mut; + + task_thread_pool thread_pool(nof_workers); + thread_pool.start(); + + auto task = [&workers_started, &workers_finished, &mut](uint32_t worker_id) { + { + std::lock_guard lock(mut); + workers_started++; + } + sleep(1); + std::lock_guard lock(mut); + std::cout << "worker " << worker_id << " has finished\n"; + workers_finished++; + }; + + for (uint32_t i = 0; i < nof_workers; ++i) { + thread_pool.push_task(task); + } + + while (workers_started != nof_workers) { + usleep(10); + } + + std::cout << "stopping thread pool...\n"; + thread_pool.stop(); + std::cout << "thread pool stopped.\n"; + + TESTASSERT(workers_finished == nof_workers); + + std::cout << "outcome: Success\n"; + std::cout << "===================================================\n"; + return 0; +} + int main() { TESTASSERT(test_multiqueue() == 0); TESTASSERT(test_multiqueue_threading() == 0); TESTASSERT(test_multiqueue_threading2() == 0); TESTASSERT(test_multiqueue_threading3() == 0); + + TESTASSERT(test_task_thread_pool() == 0); + TESTASSERT(test_task_thread_pool2() == 0); }