diff --git a/lib/include/srslte/upper/pdcp_entity_base.h b/lib/include/srslte/upper/pdcp_entity_base.h index 9a8a4ddfc..18be0d6c4 100644 --- a/lib/include/srslte/upper/pdcp_entity_base.h +++ b/lib/include/srslte/upper/pdcp_entity_base.h @@ -80,6 +80,11 @@ public: // RLC interface void write_pdu(unique_byte_buffer_t pdu); + // COUNT, HFN and SN helpers + uint32_t HFN(uint32_t count); + uint32_t SN(uint32_t count); + uint32_t COUNT(uint32_t hfn, uint32_t sn); + protected: byte_buffer_pool* pool = byte_buffer_pool::get_instance(); srslte::log* log = nullptr; @@ -90,6 +95,9 @@ protected: bool do_integrity = false; bool do_encryption = false; + uint8_t sn_len = 0; + uint8_t sn_len_bytes = 0; + std::mutex mutex; uint8_t k_rrc_enc[32] = {}; @@ -110,5 +118,20 @@ protected: void cipher_decrypt(uint8_t* ct, uint32_t ct_len, uint32_t count, uint32_t bearer_id, uint32_t direction, uint8_t* msg); }; + +inline uint32_t pdcp_entity_base::HFN(uint32_t count) +{ + return (count >> sn_len); +} + +inline uint32_t pdcp_entity_base::SN(uint32_t count) +{ + return count & (0xFFFFFFFF >> sn_len); +} + +inline uint32_t pdcp_entity_base::COUNT(uint32_t hfn, uint32_t sn) +{ + return (hfn << sn_len) | sn; +} } // namespace srslte #endif // SRSLTE_PDCP_ENTITY_BASE_H diff --git a/lib/src/upper/pdcp_entity_nr.cc b/lib/src/upper/pdcp_entity_nr.cc index 0bf59d7d6..d0e9aa982 100644 --- a/lib/src/upper/pdcp_entity_nr.cc +++ b/lib/src/upper/pdcp_entity_nr.cc @@ -74,6 +74,7 @@ void pdcp_entity_nr::write_sdu(unique_byte_buffer_t sdu, bool blocking) (do_integrity) ? "true" : "false", (do_encryption) ? "true" : "false"); // TODO + } // RLC interface @@ -92,7 +93,33 @@ void pdcp_entity_nr::write_pdu(unique_byte_buffer_t pdu) return; } - // TODO -} + // Calculate RCVD_COUNT + uint32_t rcvd_sn = get_rcvd_sn(pdu); + uint32_t rcvd_hfn, rcvd_count; + if (rcvd_sn < SN(rx_deliv) - window_size) { + rcvd_hfn = HFN(rx_deliv) + 1; + } else if (rcvd_sn >= SN(rx_deliv) + window_size) { + rcvd_hfn = HFN(rx_deliv) - 1; + } else { + rcvd_hfn = HFN(rx_deliv); + } + rcvd_count = COUNT(rcvd_hfn, rcvd_sn); + + // Integrity check + bool is_valid = integrity_check(pdu); + if (!is_valid) { + return; // Invalid packet, drop. + } + + // Decripting + cipher_decript(pdu); + + // Check valid rcvd_count + if (rcvd_count < rx_deliv /*|| received_before (TODO)*/) { + return; // Invalid count, drop. + } + // + if(rcvd_count >= rx_next) + }