diff --git a/lib/include/srslte/common/network_utils.h b/lib/include/srslte/common/network_utils.h index 84f7e122e..bccfa2d30 100644 --- a/lib/include/srslte/common/network_utils.h +++ b/lib/include/srslte/common/network_utils.h @@ -36,74 +36,93 @@ namespace srslte { -class rx_sctp_socket_ref; +class rx_sctp_socket_ref_t; /** - * @brief handles the lifetime of a SCTP socket and provides convenience methods for listening/connecting, and read/send + * Description: Class created for code reuse by different sockets */ -class sctp_socket +class base_socket_t { public: - sctp_socket(); - sctp_socket(sctp_socket&&) noexcept; - sctp_socket(const sctp_socket&) = delete; - ~sctp_socket(); - sctp_socket& operator=(sctp_socket&&) noexcept; - sctp_socket& operator=(const sctp_socket&) = delete; + base_socket_t() = default; + base_socket_t(const base_socket_t&) = delete; + base_socket_t(base_socket_t&& other) noexcept; + virtual ~base_socket_t(); + base_socket_t& operator=(const base_socket_t&) = delete; + base_socket_t& operator =(base_socket_t&&) noexcept; + + bool is_init() const { return sockfd >= 0; } + int fd() const { return sockfd; } + + // generic read/write interface + virtual int read(void* buf, size_t nbytes) const = 0; + virtual int send(void* buf, size_t nbytes) const = 0; + +protected: + void reset_(); + int bind_addr(const char* bind_addr_str, int port); + virtual int create_socket() = 0; + int connect_to(struct sockaddr_in* dest_addr, const char* dest_addr_str, int dest_port); + + int sockfd = -1; + struct sockaddr_in addr_in = {}; +}; +/** + * Description: handles the lifetime of a SCTP socket and provides convenience methods for listening/connecting, and + * read/send + */ +class sctp_socket_t final : public base_socket_t +{ +public: void reset(); int listen_addr(const char* bind_addr_str, int port); int connect_addr(const char* bind_addr_str, const char* dest_addr_str, int dest_port); - int read(void* buf, - ssize_t nbytes, - struct sockaddr_in* from = nullptr, - socklen_t fromlen = sizeof(sockaddr_in), - struct sctp_sndrcvinfo* sinfo = nullptr, - int msg_flags = 0); - int send(void* buf, ssize_t nbytes, uint32_t ppid, uint32_t stream_id); + int read_from(void* buf, + size_t nbytes, + struct sockaddr_in* from = nullptr, + socklen_t* fromlen = nullptr, + struct sctp_sndrcvinfo* sinfo = nullptr, + int msg_flags = 0) const; + int send(void* buf, size_t nbytes, uint32_t ppid, uint32_t stream_id) const; - const struct sockaddr_in& get_sockaddr_in() const { return addr_in; } - int fd() const { return sockfd; } - operator rx_sctp_socket_ref(); ///< cast to rx_sctp_socket_ref is safe + int read(void* buf, size_t nbytes) const override { return read_from(buf, nbytes, nullptr, nullptr, nullptr, 0); } + int send(void* buf, size_t nbytes) const override + { + printf("SCTP interface send is invalid\n"); + return -1; + } private: - int create_socket(); - int bind_addr(const char* bind_addr_str, int port = 0); + int create_socket() override; - int sockfd = -1; - struct sockaddr_in addr_in; - struct sockaddr_in dest_addr; + struct sockaddr_in dest_addr = {}; }; -/** - * @brief The rx_sctp_socket_ref class is a safe inteface/handler for receiving SCTP packets - * it basically forbids the user from trying to reset the socket while it is still - * registered to the rx_multisocket_handler for instance. - */ -class rx_sctp_socket_ref +class tcp_socket_t final : public base_socket_t { public: - rx_sctp_socket_ref(sctp_socket* sock_) : sock(sock_) {} - int read(void* buf, - ssize_t nbytes, - struct sockaddr_in* from = nullptr, - socklen_t fromlen = sizeof(sockaddr_in), - struct sctp_sndrcvinfo* sinfo = nullptr, - int msg_flags = 0) - { - return sock->read(buf, nbytes, from, fromlen, sinfo, msg_flags); - } - int fd() const { return sock->fd(); } + void reset(); + int listen_addr(const char* bind_addr_str, int port); + int accept_connection(); + int connect_addr(const char* bind_addr_str, const char* dest_addr_str, int dest_port); + + int read(void* buf, size_t nbytes) const override; + int send(void* buf, size_t nbytes) const override; private: - sctp_socket* sock = nullptr; + int create_socket() override; + + struct sockaddr_in dest_addr = {}; + int connfd = -1; }; class rx_multisocket_handler final : public thread { public: - using callback_t = std::function; + using sctp_callback_t = std::function; + using tcp_callback_t = std::function; rx_multisocket_handler(std::string name_, srslte::log* log_); rx_multisocket_handler(rx_multisocket_handler&&) = delete; @@ -112,7 +131,14 @@ public: rx_multisocket_handler& operator=(const rx_multisocket_handler&&) = delete; ~rx_multisocket_handler(); - bool register_sctp_socket(rx_sctp_socket_ref sock, callback_t recv_handler_); + template + bool register_socket(const Sock& s, Handler&& handler) + { + auto func = [&s, handler]() { handler(s); }; + return register_socket_(std::pair >(s.fd(), func)); + } + // bool register_sctp_socket(const sctp_socket_t& sock, const sctp_callback_t& recv_handler_); + // bool register_tcp_socket(const tcp_socket_t& sock, const tcp_callback_t& recv_handler_); void run_thread() override; @@ -124,19 +150,18 @@ private: cmd_id_t cmd = cmd_id_t::EXIT; int new_fd = -1; }; - struct sctp_handler_t { - callback_t callback; - rx_sctp_socket_ref sctp_ptr; - }; + + bool register_socket_(std::pair >&& elem); + // args std::string name; srslte::log* log_h = nullptr; // state - std::mutex socket_mutex; - std::map active_sctp_sockets; - bool running = false; - int pipefd[2] = {}; + std::mutex socket_mutex; + std::map > active_sockets; + bool running = false; + int pipefd[2] = {}; }; } // namespace srslte diff --git a/lib/src/common/network_utils.cc b/lib/src/common/network_utils.cc index 3325ff747..60367f484 100644 --- a/lib/src/common/network_utils.cc +++ b/lib/src/common/network_utils.cc @@ -30,163 +30,250 @@ namespace srslte { -sctp_socket::sctp_socket() -{ - bzero(&addr_in, sizeof(addr_in)); - bzero(&dest_addr, sizeof(dest_addr)); -} - -sctp_socket::sctp_socket(sctp_socket&& other) noexcept +base_socket_t::base_socket_t(base_socket_t&& other) noexcept { sockfd = other.sockfd; memcpy(&addr_in, &other.addr_in, sizeof(addr_in)); - // reset other without calling close - other.sockfd = -1; + other.sockfd = 0; bzero(&other.addr_in, sizeof(other.addr_in)); } - -sctp_socket::~sctp_socket() +base_socket_t::~base_socket_t() { - reset(); + if (sockfd >= 0) { + close(sockfd); + } } - -sctp_socket& sctp_socket::operator=(sctp_socket&& other) noexcept +base_socket_t& base_socket_t::operator=(base_socket_t&& other) noexcept { if (this == &other) { return *this; } - sockfd = other.sockfd; memcpy(&addr_in, &other.addr_in, sizeof(addr_in)); - other.sockfd = -1; + sockfd = other.sockfd; bzero(&other.addr_in, sizeof(other.addr_in)); + other.sockfd = 0; return *this; } -void sctp_socket::reset() +void base_socket_t::reset_() { if (sockfd >= 0) { close(sockfd); } - bzero(&addr_in, sizeof(addr_in)); - bzero(&dest_addr, sizeof(dest_addr)); + addr_in = {}; } -int sctp_socket::listen_addr(const char* bind_addr_str, int port) +int base_socket_t::bind_addr(const char* bind_addr_str, int port) { if (sockfd < 0) { - if (create_socket()) { + if (create_socket() != 0) { return -1; } } + addr_in.sin_family = AF_INET; + addr_in.sin_port = (port != 0) ? htons(port) : 0; + if (inet_pton(AF_INET, bind_addr_str, &(addr_in.sin_addr)) != 1) { + perror("inet_pton"); + return -1; + } + + if (bind(sockfd, (struct sockaddr*)&addr_in, sizeof(addr_in)) != 0) { + perror("bind()"); + return -1; + } + return 0; +} + +int base_socket_t::connect_to(struct sockaddr_in* dest_addr, const char* dest_addr_str, int dest_port) +{ + dest_addr->sin_family = AF_INET; + dest_addr->sin_port = htons(dest_port); + if (inet_pton(AF_INET, dest_addr_str, &(dest_addr->sin_addr)) != 1) { + perror("inet_pton()"); + return -1; + } + if (connect(sockfd, (struct sockaddr*)dest_addr, sizeof(*dest_addr)) == -1) { + perror("connect()"); + return -1; + } + return 0; +} + +/*********************************************************************** + * SCTP socket + **********************************************************************/ + +void sctp_socket_t::reset() +{ + reset_(); + dest_addr = {}; +} + +int sctp_socket_t::listen_addr(const char* bind_addr_str, int port) +{ + if (sockfd < 0 and create_socket() != 0) { + reset(); + return SRSLTE_ERROR; + } + // Sets the data_io_event to be able to use sendrecv_info // Subscribes to the SCTP_SHUTDOWN event, to handle graceful shutdown - struct sctp_event_subscribe evnts; - bzero(&evnts, sizeof(evnts)); - evnts.sctp_data_io_event = 1; - evnts.sctp_shutdown_event = 1; - if (setsockopt(sockfd, IPPROTO_SCTP, SCTP_EVENTS, &evnts, sizeof(evnts))) { + struct sctp_event_subscribe evnts = {}; + evnts.sctp_data_io_event = 1; + evnts.sctp_shutdown_event = 1; + if (setsockopt(sockfd, IPPROTO_SCTP, SCTP_EVENTS, &evnts, sizeof(evnts)) != 0) { perror("setsockopt"); reset(); - return -1; + return SRSLTE_ERROR; } // bind addr - if (bind_addr(bind_addr_str, port)) { + if (bind_addr(bind_addr_str, port) != 0) { reset(); - return -1; + return SRSLTE_ERROR; } // Listen for connections - if (listen(sockfd, SOMAXCONN)) { + if (listen(sockfd, SOMAXCONN) != 0) { perror("listen"); - reset(); - return -1; + return SRSLTE_ERROR; } - return 0; + return SRSLTE_SUCCESS; } -int sctp_socket::connect_addr(const char* bind_addr_str, const char* dest_addr_str, int dest_port) +int sctp_socket_t::connect_addr(const char* bind_addr_str, const char* dest_addr_str, int dest_port) { - if (sockfd < 0) { - if (bind_addr(bind_addr_str, 0)) { - return -1; - } + if (sockfd < 0 and bind_addr(bind_addr_str, 0) != 0) { + reset(); + return SRSLTE_ERROR; } - dest_addr.sin_family = AF_INET; - dest_addr.sin_port = htons(dest_port); - if (inet_pton(AF_INET, dest_addr_str, &(dest_addr.sin_addr)) != 1) { - perror("inet_pton()"); - return -1; - } - if (connect(sockfd, (struct sockaddr*)&dest_addr, sizeof(dest_addr)) == -1) { - perror("connect()"); - return -1; + if (connect_to(&dest_addr, dest_addr_str, dest_port) != 0) { + return SRSLTE_ERROR; } - return 0; + return SRSLTE_SUCCESS; } -int sctp_socket::read(void* buf, - ssize_t nbytes, - struct sockaddr_in* from, - socklen_t fromlen, - struct sctp_sndrcvinfo* sinfo, - int msg_flags) +int sctp_socket_t::read_from(void* buf, + size_t nbytes, + struct sockaddr_in* from, + socklen_t* fromlen, + struct sctp_sndrcvinfo* sinfo, + int msg_flags) const { - int rd_sz = sctp_recvmsg(sockfd, buf, nbytes, (struct sockaddr*)from, &fromlen, sinfo, &msg_flags); - if (rd_sz <= 0) { - perror("sctp read"); + if (fromlen != nullptr) { + *fromlen = sizeof(*from); } - return rd_sz; + return sctp_recvmsg(sockfd, buf, nbytes, (struct sockaddr*)from, fromlen, sinfo, &msg_flags); } -int sctp_socket::send(void* buf, ssize_t nbytes, uint32_t ppid, uint32_t stream_id) +int sctp_socket_t::send(void* buf, size_t nbytes, uint32_t ppid, uint32_t stream_id) const { return sctp_sendmsg( sockfd, buf, nbytes, (struct sockaddr*)&dest_addr, sizeof(dest_addr), htonl(ppid), 0, stream_id, 0, 0); } -sctp_socket::operator rx_sctp_socket_ref() +// Private Methods + +int sctp_socket_t::create_socket() { - return rx_sctp_socket_ref(this); + sockfd = socket(AF_INET, SOCK_SEQPACKET, IPPROTO_SCTP); + if (sockfd == -1) { + perror("Could not create SCTP socket\n"); + return -1; + } + return 0; } -// Private Methods +/*************************************************************** + * TCP Socket + **************************************************************/ -int sctp_socket::bind_addr(const char* bind_addr_str, int port) +void tcp_socket_t::reset() { - if (sockfd < 0) { - if (create_socket()) { - return -1; - } + reset_(); + dest_addr = {}; + if (connfd >= 0) { + connfd = -1; } +} - addr_in.sin_family = AF_INET; - if (inet_pton(AF_INET, bind_addr_str, &(addr_in.sin_addr)) != 1) { - perror("inet_pton"); +int tcp_socket_t::listen_addr(const char* bind_addr_str, int port) +{ + if (sockfd < 0 and bind_addr(bind_addr_str, port) != 0) { + reset(); return -1; } - addr_in.sin_port = (port != 0) ? htons(port) : 0; - if (bind(sockfd, (struct sockaddr*)&addr_in, sizeof(addr_in))) { - perror("bind()"); + + // Listen for connections + if (listen(sockfd, 1) != 0) { + perror("listen"); return -1; } + return 0; } -int sctp_socket::create_socket() +int tcp_socket_t::accept_connection() { - sockfd = socket(AF_INET, SOCK_SEQPACKET, IPPROTO_SCTP); - if (sockfd == -1) { - perror("Could not create SCTP socket\n"); + socklen_t clilen = sizeof(dest_addr); + connfd = accept(sockfd, (struct sockaddr*)&dest_addr, &clilen); + if (connfd < 0) { + perror("accept"); return -1; } return 0; } +int tcp_socket_t::connect_addr(const char* bind_addr_str, const char* dest_addr_str, int dest_port) +{ + if (sockfd < 0 and bind_addr(bind_addr_str, 0) != 0) { + return -1; + } + return connect_to(&dest_addr, dest_addr_str, dest_port); +} + +int tcp_socket_t::create_socket() +{ + sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd == -1) { + perror("Could not create TCP socket\n"); + return SRSLTE_ERROR; + } + return SRSLTE_SUCCESS; +} + +int tcp_socket_t::read(void* buf, size_t nbytes) const +{ + int n = ::read(connfd, buf, nbytes); + if (n == 0) { + return 0; + } + if (n == -1) { + perror("read"); + } + return n; +} + +int tcp_socket_t::send(void* buf, size_t nbytes) const +{ + // Loop until all bytes are sent + char* ptr = (char*)buf; + while (nbytes > 0) { + ssize_t i = ::send(connfd, ptr, nbytes, 0); + if (i < 1) { + perror("Error calling send()\n"); + return SRSLTE_ERROR; + } + ptr += i; + nbytes -= i; + } + return SRSLTE_SUCCESS; +} + /*************************************************************** * Rx Multisocket Handler **************************************************************/ @@ -222,39 +309,33 @@ rx_multisocket_handler::~rx_multisocket_handler() close(pipefd[0]); close(pipefd[1]); - // close all sockets - for (auto& handler_pair : active_sctp_sockets) { - if (close(handler_pair.first) == -1) { - rxSockError("Failed to close socket fd=%d\n", handler_pair.first); - } - } - rxSockDebug("closed.\n"); } -bool rx_multisocket_handler::register_sctp_socket(rx_sctp_socket_ref sock, callback_t recv_handler_) +bool rx_multisocket_handler::register_socket_(std::pair >&& elem) { + int fd = elem.first; std::lock_guard lock(socket_mutex); - if (sock.fd() < 0) { + if (fd < 0) { rxSockError("Provided SCTP socket must be already open\n"); return false; } - if (active_sctp_sockets.count(sock.fd()) > 0) { - rxSockError("Tried to register fd=%d, but this fd already exists\n", sock.fd()); + if (active_sockets.count(fd) > 0) { + rxSockError("Tried to register fd=%d, but this fd already exists\n", fd); return false; } - active_sctp_sockets.insert(std::make_pair(sock.fd(), sctp_handler_t{std::move(recv_handler_), sock})); + active_sockets.insert(std::move(elem)); // this unlocks the reading thread to add new connections ctrl_cmd_t msg; msg.cmd = ctrl_cmd_t::cmd_id_t::NEW_FD; - msg.new_fd = sock.fd(); + msg.new_fd = fd; if (write(pipefd[1], &msg, sizeof(msg)) != sizeof(msg)) { rxSockError("while writing to control pipe\n"); } - rxSockDebug("socket fd=%d has been registered.\n", sock.fd()); + rxSockDebug("socket fd=%d has been registered.\n", fd); return true; } @@ -284,16 +365,14 @@ void rx_multisocket_handler::run_thread() std::lock_guard lock(socket_mutex); - // call read callback for all SCTP connections - for (auto& handler_pair : active_sctp_sockets) { + // call read callback for all SCTP/TCP/UDP connections + for (auto& handler_pair : active_sockets) { if (not FD_ISSET(handler_pair.first, &read_fd_set)) { continue; } - handler_pair.second.callback(handler_pair.second.sctp_ptr); + handler_pair.second(); } - // TODO: For TCP and UDP - // handle ctrl messages if (FD_ISSET(pipefd[0], &read_fd_set)) { ctrl_cmd_t msg; diff --git a/lib/test/common/network_utils_test.cc b/lib/test/common/network_utils_test.cc index ee1a3f8e7..337fb1b4f 100644 --- a/lib/test/common/network_utils_test.cc +++ b/lib/test/common/network_utils_test.cc @@ -40,7 +40,7 @@ int test_socket_handler() int counter = 0; srslte::byte_buffer_pool* pool = srslte::byte_buffer_pool::get_instance(); - srslte::sctp_socket server_sock, client_sock; + srslte::sctp_socket_t server_sock, client_sock; srslte::rx_multisocket_handler sockhandler("RXSOCKETS", &log); TESTASSERT(server_sock.listen_addr("127.0.100.1", 36412) == 0); @@ -48,7 +48,7 @@ int test_socket_handler() TESTASSERT(client_sock.connect_addr("127.0.0.1", "127.0.100.1", 36412) == 0); - sockhandler.register_sctp_socket(server_sock, [pool, &log, &counter](srslte::rx_sctp_socket_ref sock) { + sockhandler.register_socket(server_sock, [pool, &log, &counter](const srslte::sctp_socket_t& sock) { srslte::unique_byte_buffer_t pdu = srslte::allocate_unique_buffer(*pool, true); int rd_sz = sock.read(pdu->msg, pdu->get_tailroom()); if (rd_sz > 0) {