NOTE(review): this patch was recovered from a copy in which line breaks were collapsed and the targets of
angle-bracket #include directives were stripped. Indentation, blank context lines and hunk offsets are reconstructed
from the original hunk headers, and the index hashes predate the edits below: regenerate the patch from the working
tree before applying it. The -i option added to mat_test is not described in usage(), which lies outside the visible
hunks - TODO add it there.

diff --git a/lib/include/srslte/phy/utils/mat.h b/lib/include/srslte/phy/utils/mat.h
index 44af771ce..7181e86e5 100644
--- a/lib/include/srslte/phy/utils/mat.h
+++ b/lib/include/srslte/phy/utils/mat.h
@@ -24,6 +24,7 @@
 
 #include "srslte/config.h"
 #include "srslte/phy/utils/simd.h"
+#include <stdint.h>
 
 /* Generic implementation for complex reciprocal */
 SRSLTE_API cf_t srslte_mat_cf_recip_gen(cf_t a);
@@ -203,4 +204,42 @@ static inline void srslte_mat_2x2_mmse_simd(simd_cf_t y0,
 }
 
 #endif /* SRSLTE_SIMD_CF_SIZE != 0 */
+
+/**
+ * Workspace used to invert N x N complex matrices.
+ *
+ * Set it up with srslte_matrix_NxN_inv_init(), use it for any number of srslte_matrix_NxN_inv_run() calls and
+ * release it with srslte_matrix_NxN_inv_free().
+ */
+typedef struct {
+  uint32_t N;          ///< Matrix dimension
+  cf_t*    row_buffer; ///< Scratch row of 2 * N samples
+  cf_t*    matrix;     ///< Augmented matrix: N rows of 2 * N samples
+} srslte_matrix_NxN_inv_t;
+
+/**
+ * Allocates the workspace for N x N matrices.
+ *
+ * @param q Workspace to set up
+ * @param N Matrix dimension, greater than zero
+ * @return SRSLTE_SUCCESS, or SRSLTE_ERROR if q is NULL, N is zero or an allocation fails
+ */
+SRSLTE_API int srslte_matrix_NxN_inv_init(srslte_matrix_NxN_inv_t* q, uint32_t N);
+
+/**
+ * Computes the inverse of a row-major N x N matrix.
+ *
+ * @param q   Workspace set up for the same N
+ * @param in  Matrix to invert, N * N samples
+ * @param out Receives the inverse, N * N samples
+ */
+SRSLTE_API void srslte_matrix_NxN_inv_run(srslte_matrix_NxN_inv_t* q, cf_t* in, cf_t* out);
+
+/**
+ * Releases the memory held by the workspace and resets it.
+ *
+ * @param q Workspace to release, may be NULL
+ */
+SRSLTE_API void srslte_matrix_NxN_inv_free(srslte_matrix_NxN_inv_t* q);
+
 #endif /* SRSLTE_MAT_H */
diff --git a/lib/src/phy/utils/mat.c b/lib/src/phy/utils/mat.c
index e3135d12b..ad73f2a3d 100644
--- a/lib/src/phy/utils/mat.c
+++ b/lib/src/phy/utils/mat.c
@@ -24,3 +24,9 @@
 
 #include "srslte/phy/utils/mat.h"
+#include "srslte/phy/utils/vector.h"
+#include "srslte/phy/utils/vector_simd.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
 
@@ -354,3 +360,237 @@ inline void srslte_mat_2x2_mmse_avx(__m256 y0,
 }
 
 #endif /* LV_HAVE_AVX */
+
+/**
+ * Allocates the working memory needed to invert N x N complex matrices.
+ *
+ * @param q Workspace to set up; its previous contents are overwritten without being released
+ * @param N Matrix dimension, greater than zero
+ * @return SRSLTE_SUCCESS, or SRSLTE_ERROR if q is NULL, N is zero or an allocation fails. After an error the
+ *         workspace owns no memory.
+ */
+int srslte_matrix_NxN_inv_init(srslte_matrix_NxN_inv_t* q, uint32_t N)
+{
+  if (!q) {
+    return SRSLTE_ERROR;
+  }
+
+  // Known state first, so that run() and free() are safe on a workspace whose init failed
+  memset(q, 0, sizeof(srslte_matrix_NxN_inv_t));
+
+  if (!N) {
+    return SRSLTE_ERROR;
+  }
+
+  // One augmented row, used to swap rows and to hold the scaled pivot row
+  q->row_buffer = srslte_vec_malloc(sizeof(cf_t) * N * 2);
+
+  // Augmented matrix [in | identity]
+  q->matrix = srslte_vec_malloc(sizeof(cf_t) * N * N * 2);
+
+  if (!q->row_buffer || !q->matrix) {
+    perror("malloc");
+    srslte_matrix_NxN_inv_free(q);
+    return SRSLTE_ERROR;
+  }
+
+  q->N = N;
+
+  return SRSLTE_SUCCESS;
+}
+
+/* Multiplies len samples of x by the scalar h and stores them in z. z may be the same buffer as x. */
+static inline void srslte_vec_sc_prod_ccc_simd_inline(const cf_t* x, const cf_t h, cf_t* z, const int len)
+{
+  int i = 0;
+
+#if SRSLTE_SIMD_F_SIZE
+
+#ifdef HAVE_NEON
+  i = srslte_vec_sc_prod_ccc_simd2(x, h, z, len);
+#else
+  const simd_f_t hre = srslte_simd_f_set1(__real__ h);
+  const simd_f_t him = srslte_simd_f_set1(__imag__ h);
+
+  for (; i < len - SRSLTE_SIMD_F_SIZE / 2 + 1; i += SRSLTE_SIMD_F_SIZE / 2) {
+    // Unaligned access: matrix rows start every 2 * N samples, which is not a SIMD boundary for every N
+    simd_f_t temp = srslte_simd_f_loadu((float*)&x[i]);
+
+    simd_f_t m1 = srslte_simd_f_mul(hre, temp);
+    simd_f_t sw = srslte_simd_f_swap(temp);
+    simd_f_t m2 = srslte_simd_f_mul(him, sw);
+    simd_f_t r  = srslte_simd_f_addsub(m1, m2);
+    srslte_simd_f_storeu((float*)&z[i], r);
+  }
+
+#endif
+#endif
+  for (; i < len; i++) {
+    // Take a copy of the sample: when z is x the real part is overwritten before the imaginary part is computed
+    const cf_t xi = x[i];
+
+    __real__ z[i] = __real__ xi * __real__ h - __imag__ xi * __imag__ h;
+    __imag__ z[i] = __real__ xi * __imag__ h + __imag__ xi * __real__ h;
+  }
+}
+
+/* Element-wise z = x - y over len floats */
+static inline void srslte_vec_sub_fff_simd_inline(const float* x, const float* y, float* z, const int len)
+{
+  int i = 0;
+
+#if SRSLTE_SIMD_F_SIZE
+
+  for (; i < len - SRSLTE_SIMD_F_SIZE + 1; i += SRSLTE_SIMD_F_SIZE) {
+    simd_f_t a = srslte_simd_f_loadu(&x[i]);
+    simd_f_t b = srslte_simd_f_loadu(&y[i]);
+
+    simd_f_t r = srslte_simd_f_sub(a, b);
+
+    srslte_simd_f_storeu(&z[i], r);
+  }
+#endif
+
+  for (; i < len; i++) {
+    z[i] = x[i] - y[i];
+  }
+}
+
+/* Element-wise z = x - y over len complex samples, handled as 2 * len floats */
+static inline void srslte_vec_sub_ccc_simd_inline(const cf_t* x, const cf_t* y, cf_t* z, const int len)
+{
+  srslte_vec_sub_fff_simd_inline((float*)x, (float*)y, (float*)z, 2 * len);
+}
+
+/* Squared magnitude of x */
+static inline float squared_abs(cf_t x)
+{
+  return __real__ x * __real__ x + __imag__ x * __imag__ x;
+}
+
+/* Returns 1 / x. The result is not finite when x is zero. */
+static inline cf_t reciprocal(cf_t x)
+{
+  cf_t y;
+
+  float mod = squared_abs(x);
+  __real__ y = (__real__ x) / mod;
+  __imag__ y = -(__imag__ x) / mod;
+
+  return y;
+}
+
+/* Exchanges rows a and b of the augmented matrix through the scratch row */
+static void swap_rows(srslte_matrix_NxN_inv_t* q, int a, int b)
+{
+  const size_t row_len = 2 * (size_t)q->N;
+  cf_t*        row_a   = &q->matrix[row_len * a];
+  cf_t*        row_b   = &q->matrix[row_len * b];
+
+  memcpy(q->row_buffer, row_a, sizeof(cf_t) * row_len);
+  memcpy(row_a, row_b, sizeof(cf_t) * row_len);
+  memcpy(row_b, q->row_buffer, sizeof(cf_t) * row_len);
+}
+
+/*
+ * Clears column `pivot` in the rows [first, last) by subtracting the right multiple of row `pivot` from each of them.
+ * Row `pivot` must already hold a one in column `pivot`.
+ */
+static void eliminate_column(srslte_matrix_NxN_inv_t* q, int pivot, int first, int last)
+{
+  const int   row_len   = 2 * (int)q->N;
+  const cf_t* pivot_row = &q->matrix[row_len * pivot];
+
+  for (int j = first; j < last; j++) {
+    cf_t* row = &q->matrix[row_len * j];
+    cf_t  a   = row[pivot];
+
+    // Already clear: nothing to subtract
+    if (a == 0.0f) {
+      continue;
+    }
+
+    srslte_vec_sc_prod_ccc_simd_inline(pivot_row, a, q->row_buffer, row_len);
+    srslte_vec_sub_ccc_simd_inline(row, q->row_buffer, row, row_len);
+  }
+}
+
+/**
+ * Inverts a row-major N x N complex matrix by Gauss-Jordan elimination with row pivoting.
+ *
+ * @param q   Workspace set up with srslte_matrix_NxN_inv_init() for the same N
+ * @param in  Matrix to invert, N * N samples
+ * @param out Receives the inverse, N * N samples
+ *
+ * Nothing is written if an argument is NULL or the workspace is not initialised. There is no return value to report
+ * a singular input: out is then not meaningful and may hold non-finite values.
+ */
+void srslte_matrix_NxN_inv_run(srslte_matrix_NxN_inv_t* q, cf_t* in, cf_t* out)
+{
+  if (!q || !in || !out || !q->matrix || !q->row_buffer) {
+    return;
+  }
+
+  const int N       = (int)q->N;
+  const int row_len = 2 * N;
+
+  // 0) Build the augmented matrix [in | identity]
+  for (int i = 0; i < N; i++) {
+    cf_t* row = &q->matrix[row_len * i];
+
+    memcpy(row, &in[N * i], sizeof(cf_t) * N);
+    memset(&row[N], 0, sizeof(cf_t) * N);
+    row[N + i] = 1.0f;
+  }
+
+  // 1) Clear the entries above the diagonal, taking the pivots from the bottom-right corner upwards
+  for (int p = N - 1; p >= 0; p--) {
+    cf_t* pivot_row = &q->matrix[row_len * p];
+
+    // Row pivoting: of the rows that have not been a pivot yet (0 to p), take the largest entry in column p
+    int   best_row = p;
+    float best_mag = squared_abs(pivot_row[p]);
+    for (int j = 0; j < p; j++) {
+      float mag = squared_abs(q->matrix[row_len * j + p]);
+      if (mag > best_mag) {
+        best_mag = mag;
+        best_row = j;
+      }
+    }
+    if (best_row != p) {
+      swap_rows(q, best_row, p);
+    }
+
+    // Scale the pivot row so that its diagonal entry becomes one
+    srslte_vec_sc_prod_ccc_simd_inline(pivot_row, reciprocal(pivot_row[p]), pivot_row, row_len);
+
+    eliminate_column(q, p, 0, p);
+  }
+
+  // 2) The left half is now lower triangular with a unit diagonal: clear the entries below the diagonal
+  for (int p = 0; p < N - 1; p++) {
+    eliminate_column(q, p, p + 1, N);
+  }
+
+  // 3) The right half of the augmented matrix is the inverse
+  for (int i = 0; i < N; i++) {
+    memcpy(&out[N * i], &q->matrix[row_len * i + N], sizeof(cf_t) * N);
+  }
+}
+
+/**
+ * Releases the memory of the workspace and resets it, so that releasing it again is harmless.
+ *
+ * @param q Workspace to release, may be NULL
+ */
+void srslte_matrix_NxN_inv_free(srslte_matrix_NxN_inv_t* q)
+{
+  if (q) {
+    // free(NULL) is a no-op, which covers a partially initialised workspace
+    free(q->matrix);
+    free(q->row_buffer);
+
+    // Default all to zero
+    memset(q, 0, sizeof(srslte_matrix_NxN_inv_t));
+  }
+}
diff --git a/lib/src/phy/utils/test/mat_test.c b/lib/src/phy/utils/test/mat_test.c
index 682b57269..4b06997e1 100644
--- a/lib/src/phy/utils/test/mat_test.c
+++ b/lib/src/phy/utils/test/mat_test.c
@@ -30,7 +30,8 @@
 #include "srslte/phy/utils/mat.h"
 #include "srslte/phy/utils/vector.h"
 #include "srslte/phy/utils/vector_simd.h"
 
+static bool inverter = false;
 static bool zf_solver = false;
 static bool mmse_solver = false;
 static bool verbose = false;
@@ -82,8 +83,11 @@ void usage(char* prog)
 void parse_args(int argc, char** argv)
 {
   int opt;
-  while ((opt = getopt(argc, argv, "mzvh")) != -1) {
+  while ((opt = getopt(argc, argv, "imzvh")) != -1) {
     switch (opt) {
+      case 'i':
+        inverter = true;
+        break;
       case 'm':
         mmse_solver = true;
         break;
@@ -267,6 +271,61 @@ static bool test_vec_dot_prod_ccc(void)
   return (cabsf(res - gold) < MAXIMUM_ERROR);
 }
 
+/* Inverts the n x n matrix x into y and checks that x * y is the identity matrix within tol */
+static bool check_matrix_inv(uint32_t n, cf_t* x, cf_t* y, float tol)
+{
+  srslte_matrix_NxN_inv_t matrix_nxn_inv = {0};
+
+  if (srslte_matrix_NxN_inv_init(&matrix_nxn_inv, n) != SRSLTE_SUCCESS) {
+    return false;
+  }
+
+  srslte_matrix_NxN_inv_run(&matrix_nxn_inv, x, y);
+
+  srslte_matrix_NxN_inv_free(&matrix_nxn_inv);
+
+  for (uint32_t i = 0; i < n; i++) {
+    for (uint32_t j = 0; j < n; j++) {
+      // Entry (i, j) of x * y minus the identity matrix
+      cf_t err = (i == j) ? -1.0f : 0.0f;
+      for (uint32_t k = 0; k < n; k++) {
+        err += x[i * n + k] * y[k * n + j];
+      }
+
+      // Negated "<" so that a NaN in the result fails the test
+      if (!(cabsf(err) < tol)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+static bool test_matrix_inv(void)
+{
+  enum { N = 64 };
+  __attribute__((aligned(256))) cf_t x[N * N];
+  __attribute__((aligned(256))) cf_t y[N * N];
+
+  // Dense random matrix. Assuming the generator honours the [-1, +1] bounds, the diagonal offset makes it
+  // diagonally dominant, hence well conditioned, so a tight tolerance can be used
+  for (int i = 0; i < N; i++) {
+    for (int j = 0; j < N; j++) {
+      x[i * N + j] = srslte_random_uniform_complex_dist(random_gen, -1.0f, +1.0f);
+    }
+    x[i * N + i] += 2.0f * N;
+  }
+
+  if (!check_matrix_inv(N, x, y, 1e-4f)) {
+    return false;
+  }
+
+  // Odd size with a zero diagonal and a complex entry: needs row swaps and ends rows off a SIMD boundary
+  cf_t z[3 * 3] = {0.0f, 2.0f, 0.0f, 0.0f, 0.0f, 4.0f, _Complex_I, 0.0f, 0.0f};
+
+  return check_matrix_inv(3, z, y, 1e-6f);
+}
+
 int main(int argc, char** argv)
 {
   bool passed = true;
@@ -292,6 +351,10 @@ int main(int argc, char** argv)
 #endif /* SRSLTE_SIMD_CF_SIZE != 0*/
   }
 
+  if (inverter) {
+    RUN_TEST(test_matrix_inv);
+  }
+
   RUN_TEST(test_vec_dot_prod_ccc);
 
   printf("%s!\n", (passed) ? "Ok" : "Failed");