diff --git a/lib/include/srslte/common/multiqueue.h b/lib/include/srslte/common/multiqueue.h index b127ab67e..5b4c46be0 100644 --- a/lib/include/srslte/common/multiqueue.h +++ b/lib/include/srslte/common/multiqueue.h @@ -51,69 +51,103 @@ class multiqueue_handler using std::queue::size; using std::queue::empty; using std::queue::front; + + std::condition_variable cv_full; + bool active = true; }; public: explicit multiqueue_handler(uint32_t capacity_ = std::numeric_limits::max()) : capacity(capacity_) {} - ~multiqueue_handler() + ~multiqueue_handler() { reset(); } + + void reset() { - std::lock_guard lck(mutex); - queues_active.clear(); - queues.clear(); + std::unique_lock lock(mutex); running = false; + while (nof_threads_waiting > 0) { + uint32_t size = queues.size(); + lock.unlock(); + cv_empty.notify_one(); + for (uint32_t i = 0; i < size; ++i) { + queues[i].cv_full.notify_all(); + } + lock.lock(); + // wait for all threads to unblock + cv_exit.wait(lock); + } + queues.clear(); } int add_queue() { - uint32_t qidx = 0; - for (; qidx < queues_active.size() and queues_active[qidx]; ++qidx) + uint32_t qidx = 0; + std::lock_guard lock(mutex); + if (not running) { + return -1; + } + for (; qidx < queues.size() and queues[qidx].active; ++qidx) ; - if (qidx == queues_active.size()) { + if (qidx == queues.size()) { // create new queue - std::lock_guard lck(mutex); - queues_active.push_back(true); queues.emplace_back(); } else { - queues_active[qidx] = true; + queues[qidx].active = true; } return (int)qidx; } int nof_queues() { - std::lock_guard lck(mutex); - return std::count(queues_active.begin(), queues_active.end(), true); + std::lock_guard lock(mutex); + uint32_t count = 0; + for (uint32_t i = 0; i < queues.size(); ++i) { + count += queues[i].active ? 1 : 0; + } + return count; } - bool try_push(int q_idx, const myobj& value) + template + void push(int q_idx, FwdRef&& value) { - if (not running) { - return false; + { + std::unique_lock lock(mutex); + while (is_queue_active_(q_idx) and queues[q_idx].size() >= capacity) { + nof_threads_waiting++; + queues[q_idx].cv_full.wait(lock); + nof_threads_waiting--; + } + if (not is_queue_active_(q_idx)) { + cv_exit.notify_one(); + return; + } + queues[q_idx].push(std::forward(value)); } + cv_empty.notify_one(); + } + + bool try_push(int q_idx, const myobj& value) + { { - std::lock_guard lck(mutex); - if (queues[q_idx].size() >= capacity) { + std::lock_guard lock(mutex); + if (not is_queue_active_(q_idx) or queues[q_idx].size() >= capacity) { return false; } queues[q_idx].push(value); } - cv.notify_one(); + cv_empty.notify_one(); return true; } std::pair try_push(int q_idx, myobj&& value) { - if (not running) { - return {false, std::move(value)}; - } { std::lock_guard lck(mutex); - if (queues[q_idx].size() >= capacity) { + if (not is_queue_active_(q_idx) or queues[q_idx].size() >= capacity) { return {false, std::move(value)}; } queues[q_idx].push(std::move(value)); } - cv.notify_one(); + cv_empty.notify_one(); return {true, std::move(value)}; } @@ -121,19 +155,26 @@ public: { std::unique_lock lock(mutex); while (running) { - cv.wait(lock); // Round-robin for all queues - for (uint32_t i = 0; queues.size(); ++i) { + for (const queue_wrapper& q : queues) { spin_idx = (spin_idx + 1) % queues.size(); - if (queues_active[spin_idx] and not queues[spin_idx].empty()) { + if (is_queue_active_(spin_idx) and not queues[spin_idx].empty()) { if (value) { *value = std::move(queues[spin_idx].front()); } queues[spin_idx].pop(); + if (nof_threads_waiting > 0) { + lock.unlock(); + queues[spin_idx].cv_full.notify_one(); + } return spin_idx; } } + nof_threads_waiting++; + cv_empty.wait(lock); + nof_threads_waiting--; } + cv_exit.notify_one(); return -1; } @@ -152,14 +193,14 @@ public: const myobj& front(int qidx) { std::lock_guard lck(mutex); - return queues.front(); + return queues[qidx].front(); } void erase_queue(int qidx) { std::lock_guard lck(mutex); - if (queues_active[qidx]) { - queues_active[qidx] = false; + if (is_queue_active_(qidx)) { + queues[qidx].active = false; while (not queues[qidx].empty()) { queues[qidx].pop(); } @@ -169,17 +210,19 @@ public: bool is_queue_active(int qidx) { std::lock_guard lck(mutex); - return queues_active[qidx]; + return is_queue_active_(qidx); } private: + bool is_queue_active_(int qidx) const { return running and queues[qidx].active; } + std::mutex mutex; - std::condition_variable cv; + std::condition_variable cv_empty, cv_exit; uint32_t spin_idx = 0; bool running = true; - std::vector queues_active; std::vector queues; uint32_t capacity; + uint32_t nof_threads_waiting = 0; }; } // namespace srslte diff --git a/lib/test/common/queue_test.cc b/lib/test/common/queue_test.cc index 43c68a0e1..86b9fd94e 100644 --- a/lib/test/common/queue_test.cc +++ b/lib/test/common/queue_test.cc @@ -21,6 +21,8 @@ #include #include +#include +#include #define TESTASSERT(cond) \ { \ @@ -98,7 +100,117 @@ int test_multiqueue() return 0; } +int test_multiqueue_threading() +{ + std::cout << "\n===== TEST multiqueue threading test: start =====\n"; + + int capacity = 4, number, start_number = 2, nof_pushes = capacity + 1; + multiqueue_handler multiqueue(capacity); + int qid1 = multiqueue.add_queue(); + auto push_blocking_func = [&multiqueue](int qid, int start_value, int nof_pushes, bool* is_running) { + for (int i = 0; i < nof_pushes; ++i) { + multiqueue.push(qid, start_value + i); + std::cout << "t1: pushed item " << i << std::endl; + } + std::cout << "t1: pushed all items\n"; + *is_running = false; + }; + + bool t1_running = true; + std::thread t1(push_blocking_func, qid1, start_number, nof_pushes, &t1_running); + + TESTASSERT(t1_running) + usleep(1000); + TESTASSERT((int)multiqueue.size(qid1) == capacity) + + for (int i = 0; i < nof_pushes; ++i) { + TESTASSERT(multiqueue.wait_pop(&number) == qid1) + TESTASSERT(number == start_number + i) + std::cout << "main: popped item " << i << "\n"; + } + std::cout << "main: popped all items\n"; + usleep(1000); + TESTASSERT(not t1_running) + TESTASSERT(multiqueue.size(qid1) == 0) + + multiqueue.reset(); + t1.join(); + + std::cout << "outcome: Success\n"; + std::cout << "==================================================\n"; + + return 0; +} + +int test_multiqueue_threading2() +{ + std::cout << "\n===== TEST multiqueue threading test 2: start =====\n"; + // Description: push items until blocking in thread t1. Unblocks in main thread by calling multiqueue.reset() + + int capacity = 4, start_number = 2, nof_pushes = capacity + 1; + multiqueue_handler multiqueue(capacity); + int qid1 = multiqueue.add_queue(); + auto push_blocking_func = [&multiqueue](int qid, int start_value, int nof_pushes, bool* is_running) { + for (int i = 0; i < nof_pushes; ++i) { + multiqueue.push(qid, start_value + i); + } + std::cout << "t1: pushed all items\n"; + *is_running = false; + }; + + bool t1_running = true; + std::thread t1(push_blocking_func, qid1, start_number, nof_pushes, &t1_running); + + TESTASSERT(t1_running) + usleep(1000); + TESTASSERT((int)multiqueue.size(qid1) == capacity) + + multiqueue.reset(); + t1.join(); + + std::cout << "outcome: Success\n"; + std::cout << "===================================================\n"; + + return 0; +} + +int test_multiqueue_threading3() +{ + std::cout << "\n===== TEST multiqueue threading test 3: start =====\n"; + // pop will block in a separate thread, but multiqueue.reset() will unlock it + + int capacity = 4; + multiqueue_handler multiqueue(capacity); + int qid1 = multiqueue.add_queue(); + auto pop_blocking_func = [&multiqueue](int qid, bool* success) { + int number; + int id = multiqueue.wait_pop(&number); + *success = id < 0; + }; + + bool t1_success = false; + std::thread t1(pop_blocking_func, qid1, &t1_success); + + TESTASSERT(not t1_success) + usleep(1000); + TESTASSERT(not t1_success) + TESTASSERT((int)multiqueue.size(qid1) == 0) + + // Should be able to unlock all + multiqueue.reset(); + t1.join(); + TESTASSERT(t1_success) + + std::cout << "outcome: Success\n"; + std::cout << "===================================================\n"; + + return 0; +} + int main() { TESTASSERT(test_multiqueue() == 0); + TESTASSERT(test_multiqueue_threading() == 0); + TESTASSERT(test_multiqueue_threading2() == 0); + TESTASSERT(test_multiqueue_threading3() == 0); }