diff --git a/lib/include/srslte/adt/observer.h b/lib/include/srslte/adt/observer.h new file mode 100644 index 000000000..6ccd07407 --- /dev/null +++ b/lib/include/srslte/adt/observer.h @@ -0,0 +1,212 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#ifndef SRSLTE_OBSERVER_H +#define SRSLTE_OBSERVER_H + +#include +#include +#include +#include + +namespace srslte { + +using observer_id = std::size_t; +const size_t invalid_observer_id = std::numeric_limits::max(); + +template +class base_observable +{ +public: + using callback_t = std::function; + + //! Subscribe Observer that is a callback + template + typename std::enable_if::value, observer_id>::type + subscribe(Callable&& callable) + { + return subscribe_common(callable); + } + + //! Subscribe Observer type with method Observer::trigger(Args...) + template + typename std::enable_if::value, observer_id>::type + subscribe(Observer& observer) + { + return subscribe_common([&observer](Args... args) { observer.trigger(std::forward(args)...); }); + } + + //! Subscribe Observer type with custom trigger method + template + observer_id subscribe(Observer& observer, void (Observer::*trigger_method)(Args...)) + { + return subscribe_common( + [&observer, trigger_method](Args... args) { (observer.*trigger_method)(std::forward(args)...); }); + } + + //! Unsubscribe Observer + bool unsubscribe(observer_id id) + { + if (id < observers.size() and static_cast(observers[id])) { + observers[id] = nullptr; + return true; + } + return false; + } + + size_t nof_observers() const + { + size_t count = 0; + for (auto& slot : observers) { + count += static_cast(slot) ? 1 : 0; + } + return count; + } + + //! Signal result to observers + void dispatch(Args... args) + { + for (auto& obs_callback : observers) { + if (obs_callback) { + obs_callback(std::forward(args)...); + } + } + } + +protected: + using observer_list_t = std::deque; + + ~base_observable() = default; + + template + observer_id subscribe_common(Callable&& callable) + { + size_t id = 0; + for (auto& slot : observers) { + if (not static_cast(slot)) { + // empty slot found + slot = std::forward(callable); + return id; + } + id++; + } + // append to end of list + observers.emplace_back(std::forward(callable)); + return observers.size() - 1; + } + + observer_list_t observers; +}; + +template +class observable : public base_observable +{}; + +//! Special case of observable for event types +template +class event_dispatcher : public base_observable +{}; + +//! Event Subject that enqueues events and only signals observers when ::process() is called +template +class event_queue : public base_observable +{ + using base_t = base_observable; + +public: + template + void enqueue(Args&&... args) + { + pending_events.emplace_back(std::forward(args)...); + } + + void process() + { + for (auto& ev : pending_events) { + base_t::dispatch(ev); + } + pending_events.clear(); + } + +private: + // forbid direct dispatches + using base_t::dispatch; + + std::vector pending_events; +}; + +//! RAII class to automatically unsubscribe an observer from an Event +template +class unique_observer_id +{ + using subject_t = base_observable; + +public: + unique_observer_id(subject_t& parent_, observer_id id_) : parent(&parent_), id(id_) {} + template + unique_observer_id(subject_t& parent_, T&& callable) : parent(&parent_) + { + id = parent->subscribe(std::forward(callable)); + } + template + unique_observer_id(subject_t& parent_, Observer& observer, void (Observer::*trigger_method)(const Event&)) : + parent(&parent_) + { + id = parent->subscribe(observer, trigger_method); + } + unique_observer_id(unique_observer_id&& other) noexcept : parent(other.parent), id(other.id) + { + other.parent = nullptr; + } + unique_observer_id(const unique_observer_id& other) = delete; + + unique_observer_id& operator=(unique_observer_id&& other) noexcept + { + parent = other.parent; + id = other.id; + other.id = invalid_observer_id; + return *this; + } + unique_observer_id& operator=(const unique_observer_id& other) = delete; + ~unique_observer_id() + { + if (id != invalid_observer_id) { + parent->unsubscribe(id); + } + } + + observer_id get_id() const { return id; } + bool is_valid() const { return id != invalid_observer_id; } + observer_id release() + { + observer_id ret = id; + id = invalid_observer_id; + return ret; + } + +private: + subject_t* parent; + observer_id id; +}; + +} // namespace srslte + +#endif // SRSLTE_OBSERVER_H diff --git a/lib/test/adt/CMakeLists.txt b/lib/test/adt/CMakeLists.txt index 25b575f75..e9cc3cb59 100644 --- a/lib/test/adt/CMakeLists.txt +++ b/lib/test/adt/CMakeLists.txt @@ -41,3 +41,7 @@ add_test(span_test span_test) add_executable(interval_test interval_test.cc) target_link_libraries(interval_test srslte_common) add_test(interval_test interval_test) + +add_executable(observer_test observer_test.cc) +target_link_libraries(observer_test srslte_common) +add_test(observer_test observer_test) diff --git a/lib/test/adt/observer_test.cc b/lib/test/adt/observer_test.cc new file mode 100644 index 000000000..a416e6034 --- /dev/null +++ b/lib/test/adt/observer_test.cc @@ -0,0 +1,256 @@ +/* + * Copyright 2013-2020 Software Radio Systems Limited + * + * This file is part of srsLTE. + * + * srsLTE is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of + * the License, or (at your option) any later version. + * + * srsLTE is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * A copy of the GNU Affero General Public License can be found in + * the LICENSE file in the top-level directory of this distribution + * and at http://www.gnu.org/licenses/. + * + */ + +#include "srslte/adt/observer.h" +#include "srslte/common/test_common.h" + +struct M { + M() = default; + explicit M(int v) : val(v) {} + M(M&&) noexcept = default; + M(const M&) = delete; + M& operator=(M&&) noexcept = default; + M& operator=(const M&) = delete; + + M& operator+=(int i) + { + val += i; + return *this; + } + + int val = 0; +}; + +struct lval_observer_tester { + void trigger(M v_) { v = std::move(v_); } + void foo(M v_) + { + v = std::move(v_); + v += 1; + } + M v; +}; +struct cref_observer_tester { + void trigger(const M& v_) { v.val = v_.val; } + void foo(const M& v_) { v.val = v_.val + 1; } + M v; +}; +struct lref_observer_tester { + void trigger(M& v_) { v.val = v_.val; } + void foo(M& v_) { v.val = v_.val + 1; } + M v; +}; +struct rref_observer_tester { + void trigger(M&& v_) { v = std::move(v_); } + void foo(M&& v_) + { + v = std::move(v_); + v += 1; + } + M v; +}; + +int observable_test() +{ + // TEST l-value arguments passed by value + { + M val; + srslte::observable subject; + TESTASSERT(subject.nof_observers() == 0); + + srslte::observer_id id1 = subject.subscribe([&val](M v) { val = std::move(v); }); + + lval_observer_tester observer{}, observer2{}; + srslte::observer_id id2 = subject.subscribe(observer); + srslte::observer_id id3 = subject.subscribe(observer2, &lval_observer_tester::foo); + + TESTASSERT(subject.nof_observers() == 3); + TESTASSERT(val.val == 0); + subject.dispatch(M{5}); + TESTASSERT(val.val == 5); + TESTASSERT(observer.v.val == 5); + TESTASSERT(observer2.v.val == 6); + + subject.unsubscribe(id1); + TESTASSERT(subject.nof_observers() == 2); + subject.unsubscribe(id2); + TESTASSERT(subject.nof_observers() == 1); + subject.unsubscribe(id3); + TESTASSERT(subject.nof_observers() == 0); + } + + // Test l-value arguments passed by const ref + { + M val; + srslte::observable subject; + TESTASSERT(subject.nof_observers() == 0); + + subject.subscribe([&val](const M& v) { val.val = v.val; }); + + cref_observer_tester observer{}, observer2{}; + subject.subscribe(observer); + subject.subscribe(observer2, &cref_observer_tester::foo); + + M new_val{6}; + subject.dispatch(new_val); + TESTASSERT(val.val == 6); + TESTASSERT(observer.v.val == 6); + TESTASSERT(observer2.v.val == 7); + } + + // Test l-value arguments passed by ref + { + M val; + srslte::observable subject; + TESTASSERT(subject.nof_observers() == 0); + + subject.subscribe([&val](M& v) { val = std::move(v); }); + + lref_observer_tester observer{}, observer2{}; + subject.subscribe(observer); + subject.subscribe(observer2, &lref_observer_tester::foo); + + M new_val{6}; + subject.dispatch(new_val); + TESTASSERT(val.val == 6); + TESTASSERT(observer.v.val == 6); + TESTASSERT(observer2.v.val == 7); + } + + // Test r-value arguments + { + M val; + srslte::observable subject; + TESTASSERT(subject.nof_observers() == 0); + + srslte::observer_id id1 = subject.subscribe([&val](M&& v) { val = std::move(v); }); + + rref_observer_tester observer{}, observer2{}; + srslte::observer_id id2 = subject.subscribe(observer); + srslte::observer_id id3 = subject.subscribe(observer2, &rref_observer_tester::foo); + + subject.dispatch(M{3}); + TESTASSERT(val.val == 3); + TESTASSERT(observer.v.val == 3); + TESTASSERT(observer2.v.val == 4); + + subject.unsubscribe(id1); + subject.unsubscribe(id2); + subject.unsubscribe(id3); + TESTASSERT(subject.nof_observers() == 0); + } + + return SRSLTE_SUCCESS; +} + +int event_dispatcher_test() +{ + srslte::event_dispatcher signaller; + + M val; + signaller.subscribe([&val](const M& ev) { val.val = ev.val; }); + + cref_observer_tester observer, observer2; + signaller.subscribe(observer); + signaller.subscribe(observer2, &cref_observer_tester::foo); + + TESTASSERT(val.val == 0); + TESTASSERT(observer.v.val == 0); + TESTASSERT(observer2.v.val == 0); + signaller.dispatch(M{2}); + TESTASSERT(val.val == 2); + TESTASSERT(observer.v.val == 2); + TESTASSERT(observer2.v.val == 3); + + val.val = 1; + observer.v.val = 0; + observer2.v.val = 5; + signaller.dispatch(M{2}); + TESTASSERT(val.val == 2); + TESTASSERT(observer.v.val == 2); + TESTASSERT(observer2.v.val == 3); + + return SRSLTE_SUCCESS; +} + +int event_queue_test() +{ + srslte::event_queue signaller; + + M val; + signaller.subscribe([&val](const M& ev) { val.val = ev.val; }); + cref_observer_tester observer, observer2; + signaller.subscribe(observer); + signaller.subscribe(observer2, &cref_observer_tester::foo); + + TESTASSERT(val.val == 0); + TESTASSERT(observer.v.val == 0); + TESTASSERT(observer2.v.val == 0); + signaller.enqueue(M{2}); + TESTASSERT(val.val == 0); + TESTASSERT(observer.v.val == 0); + TESTASSERT(observer2.v.val == 0); + signaller.process(); + TESTASSERT(val.val == 2); + TESTASSERT(observer.v.val == 2); + TESTASSERT(observer2.v.val == 3); + + return SRSLTE_SUCCESS; +} + +int unique_subscribe_test() +{ + { + srslte::event_dispatcher signaller; + cref_observer_tester observer; + TESTASSERT(signaller.nof_observers() == 0); + { + srslte::unique_observer_id obs{signaller, observer}; + TESTASSERT(signaller.nof_observers() == 1); + } + TESTASSERT(signaller.nof_observers() == 0); + } + + { + srslte::event_queue signaller; + cref_observer_tester observer; + TESTASSERT(signaller.nof_observers() == 0); + { + srslte::unique_observer_id obs{signaller, observer}; + TESTASSERT(signaller.nof_observers() == 1); + } + TESTASSERT(signaller.nof_observers() == 0); + } + + return SRSLTE_SUCCESS; +} + +int main() +{ + TESTASSERT(observable_test() == SRSLTE_SUCCESS); + TESTASSERT(event_dispatcher_test() == SRSLTE_SUCCESS); + TESTASSERT(event_queue_test() == SRSLTE_SUCCESS); + TESTASSERT(unique_subscribe_test() == SRSLTE_SUCCESS); + + printf("Success\n"); + + return 0; +}