Phy utils: Added initial matrix inverse

master
Xavier Arteaga 5 years ago committed by Xavier Arteaga
parent 3f6eca1aea
commit fff96d9aac

@ -24,6 +24,7 @@
#include "srslte/config.h" #include "srslte/config.h"
#include "srslte/phy/utils/simd.h" #include "srslte/phy/utils/simd.h"
#include <inttypes.h>
/* Generic implementation for complex reciprocal */ /* Generic implementation for complex reciprocal */
SRSLTE_API cf_t srslte_mat_cf_recip_gen(cf_t a); SRSLTE_API cf_t srslte_mat_cf_recip_gen(cf_t a);
@ -203,4 +204,17 @@ static inline void srslte_mat_2x2_mmse_simd(simd_cf_t y0,
} }
#endif /* SRSLTE_SIMD_CF_SIZE != 0 */ #endif /* SRSLTE_SIMD_CF_SIZE != 0 */
typedef struct {
uint32_t N;
cf_t* row_buffer;
cf_t* matrix;
} srslte_matrix_NxN_inv_t;
SRSLTE_API int srslte_matrix_NxN_inv_init(srslte_matrix_NxN_inv_t* q, uint32_t N);
SRSLTE_API void srslte_matrix_NxN_inv_run(srslte_matrix_NxN_inv_t* q, cf_t* in, cf_t* out);
SRSLTE_API void srslte_matrix_NxN_inv_free(srslte_matrix_NxN_inv_t* q);
#endif /* SRSLTE_MAT_H */ #endif /* SRSLTE_MAT_H */

@ -21,6 +21,7 @@
#include <complex.h> #include <complex.h>
#include <math.h> #include <math.h>
#include <memory.h>
#include "srslte/phy/utils/mat.h" #include "srslte/phy/utils/mat.h"
@ -241,6 +242,7 @@ inline void srslte_mat_2x2_mmse_sse(__m128 y0,
#ifdef LV_HAVE_AVX #ifdef LV_HAVE_AVX
#include <immintrin.h> #include <immintrin.h>
#include <srslte/phy/utils/vector.h>
/* AVX implementation for complex reciprocal */ /* AVX implementation for complex reciprocal */
inline __m256 srslte_mat_cf_recip_avx(__m256 a) inline __m256 srslte_mat_cf_recip_avx(__m256 a)
@ -354,3 +356,217 @@ inline void srslte_mat_2x2_mmse_avx(__m256 y0,
} }
#endif /* LV_HAVE_AVX */ #endif /* LV_HAVE_AVX */
int srslte_matrix_NxN_inv_init(srslte_matrix_NxN_inv_t* q, uint32_t N)
{
int ret = SRSLTE_SUCCESS;
if (q && N) {
// Set all to zero
bzero(q, sizeof(srslte_matrix_NxN_inv_t));
q->N = N;
q->row_buffer = srslte_vec_malloc(sizeof(cf_t) * N * 2);
if (!q->row_buffer) {
perror("malloc");
ret = SRSLTE_ERROR;
}
if (!ret) {
q->matrix = srslte_vec_malloc(sizeof(cf_t) * N * N * 2);
if (!q->matrix) {
perror("malloc");
ret = SRSLTE_ERROR;
}
}
}
return ret;
}
static inline void srslte_vec_sc_prod_ccc_simd_inline(const cf_t* x, const cf_t h, cf_t* z, const int len)
{
int i = 0;
#if SRSLTE_SIMD_F_SIZE
#ifdef HAVE_NEON
i = srslte_vec_sc_prod_ccc_simd2(x, h, z, len);
#else
const simd_f_t hre = srslte_simd_f_set1(__real__ h);
const simd_f_t him = srslte_simd_f_set1(__imag__ h);
for (; i < len - SRSLTE_SIMD_F_SIZE / 2 + 1; i += SRSLTE_SIMD_F_SIZE / 2) {
simd_f_t temp = srslte_simd_f_load((float*)&x[i]);
simd_f_t m1 = srslte_simd_f_mul(hre, temp);
simd_f_t sw = srslte_simd_f_swap(temp);
simd_f_t m2 = srslte_simd_f_mul(him, sw);
simd_f_t r = srslte_simd_f_addsub(m1, m2);
srslte_simd_f_store((float*)&z[i], r);
}
#endif
#endif
for (; i < len; i++) {
__real__ z[i] = __real__ x[i] * __real__ h - __imag__ x[i] * __imag__ h;
__imag__ z[i] = __real__ x[i] * __imag__ h + __imag__ x[i] * __real__ h;
}
}
static inline void srslte_vec_sub_fff_simd_inline(const float* x, const float* y, float* z, const int len)
{
int i = 0;
#if SRSLTE_SIMD_F_SIZE
for (; i < len - SRSLTE_SIMD_F_SIZE + 1; i += SRSLTE_SIMD_F_SIZE) {
simd_f_t a = srslte_simd_f_loadu(&x[i]);
simd_f_t b = srslte_simd_f_loadu(&y[i]);
simd_f_t r = srslte_simd_f_sub(a, b);
srslte_simd_f_storeu(&z[i], r);
}
#endif
for (; i < len; i++) {
z[i] = x[i] - y[i];
}
}
static inline void srslte_vec_sub_ccc_simd_inline(const cf_t* x, const cf_t* y, cf_t* z, const int len)
{
srslte_vec_sub_fff_simd_inline((float*)x, (float*)y, (float*)z, 2 * len);
}
static inline cf_t reciprocal(cf_t x)
{
cf_t y;
float mod = __real__ x * __real__ x + __imag__ x * __imag__ x;
__real__ y = (__real__ x) / mod;
__imag__ y = -(__imag__ x) / mod;
return y;
}
void srslte_matrix_NxN_inv_run(srslte_matrix_NxN_inv_t* q, cf_t* in, cf_t* out)
{
if (q && in && out) {
int N = q->N;
cf_t* tmp_src_ptr;
cf_t* tmp_dst_ptr;
// 0) Copy the input vector in the matrix
tmp_src_ptr = in;
tmp_dst_ptr = q->matrix;
for (int i = 0; i < N; i++) {
// Populate first half with input matrix
memcpy(tmp_dst_ptr, tmp_src_ptr, sizeof(cf_t) * N);
tmp_src_ptr += N;
tmp_dst_ptr += N;
// Populate second half with identity matrix
bzero(tmp_dst_ptr, sizeof(cf_t) * N);
tmp_dst_ptr[i] = 1.0f;
tmp_dst_ptr += N;
}
// 1) Forward elimination
for (int i = 0; i < N - 1; i++) {
tmp_src_ptr = &q->matrix[N * 2 * (N - 1 - i)];
cf_t b = tmp_src_ptr[N - 1 - i];
srslte_vec_sc_prod_ccc_simd_inline(tmp_src_ptr, reciprocal(b), tmp_src_ptr, 2 * N);
for (int j = 0; j < N - i - 1; j++) {
cf_t a = q->matrix[N * (2 * j + 1) - 1 - i];
if (a != 0.0f && b != 0.0f) {
tmp_dst_ptr = &q->matrix[N * 2 * j];
srslte_vec_sc_prod_ccc_simd_inline(tmp_dst_ptr, reciprocal(a), tmp_dst_ptr, 2 * N);
srslte_vec_sub_ccc_simd_inline(tmp_dst_ptr, tmp_src_ptr, tmp_dst_ptr, 2 * N);
}
}
}
srslte_vec_sc_prod_ccc_simd_inline(q->matrix, reciprocal(q->matrix[0]), q->matrix, 2 * N);
// 2) Backward elimination
for (int i = 0; i < N - 1; i++) {
tmp_src_ptr = &q->matrix[N * 2 * i];
cf_t b = tmp_src_ptr[i];
srslte_vec_sc_prod_ccc_simd_inline(tmp_src_ptr, reciprocal(b), tmp_src_ptr, 2 * N);
for (int j = N - 1; j > i; j--) {
cf_t a = q->matrix[N * 2 * j + i];
tmp_dst_ptr = &q->matrix[N * 2 * j];
srslte_vec_sc_prod_ccc_simd_inline(tmp_dst_ptr, reciprocal(a), tmp_dst_ptr, 2 * N);
srslte_vec_sub_ccc_simd_inline(tmp_dst_ptr, tmp_src_ptr, tmp_dst_ptr, 2 * N);
}
}
srslte_vec_sc_prod_ccc_simd_inline(&q->matrix[2 * N * (N - 1)],
reciprocal(q->matrix[2 * N * (N - 1) + N - 1]),
&q->matrix[2 * N * (N - 1)],
2 * N);
// 4) Copy result
tmp_src_ptr = &q->matrix[N];
tmp_dst_ptr = out;
for (int i = 0; i < N; i++) {
memcpy(tmp_dst_ptr, tmp_src_ptr, sizeof(cf_t) * N);
tmp_src_ptr += 2 * N;
tmp_dst_ptr += N;
}
#if 0
printf("tmp = [...\n");
for (int i = 0; i < N; i++) {
printf("\t");
for (int j = 0; j < 2 * N; j++) {
printf("%c %+.3f%+.3fi", j ==0 ? ' ' : ',', __real__ q->matrix[2 * N * i + j], __imag__ q->matrix[2 * N * i + j]);
}
printf("; ...\n");
}
printf("];\n");
printf("in = [...\n");
for (int i = 0; i < N; i++) {
printf("\t");
for (int j = 0; j < N; j++) {
printf("%c %+.3f%+.3fi", j ==0 ? ' ' : ',', __real__ in[N * i + j], __imag__ in[N * i + j]);
}
printf("; ...\n");
}
printf("];\n");
printf("out = [...\n");
for (int i = 0; i < N; i++) {
printf("\t");
for (int j = 0; j < N; j++) {
printf("%c %+.3f%+.3fi", j ==0 ? ' ' : ',', __real__ out[N * i + j], __imag__ out[N * i + j]);
}
printf("; ...\n");
}
printf("];\n");
#endif
}
}
void srslte_matrix_NxN_inv_free(srslte_matrix_NxN_inv_t* q)
{
if (q) {
if (q->matrix) {
free(q->matrix);
}
if (q->row_buffer) {
free(q->row_buffer);
}
// Default all to zero
bzero(q, sizeof(srslte_matrix_NxN_inv_t));
}
}

@ -30,7 +30,7 @@
#include "srslte/phy/utils/mat.h" #include "srslte/phy/utils/mat.h"
#include "srslte/phy/utils/vector.h" #include "srslte/phy/utils/vector.h"
#include "srslte/phy/utils/vector_simd.h" #include "srslte/phy/utils/vector_simd.h"
static bool inverter = false;
static bool zf_solver = false; static bool zf_solver = false;
static bool mmse_solver = false; static bool mmse_solver = false;
static bool verbose = false; static bool verbose = false;
@ -82,8 +82,11 @@ void usage(char* prog)
void parse_args(int argc, char** argv) void parse_args(int argc, char** argv)
{ {
int opt; int opt;
while ((opt = getopt(argc, argv, "mzvh")) != -1) { while ((opt = getopt(argc, argv, "imzvh")) != -1) {
switch (opt) { switch (opt) {
case 'i':
inverter = true;
break;
case 'm': case 'm':
mmse_solver = true; mmse_solver = true;
break; break;
@ -267,6 +270,29 @@ static bool test_vec_dot_prod_ccc(void)
return (cabsf(res - gold) < MAXIMUM_ERROR); return (cabsf(res - gold) < MAXIMUM_ERROR);
} }
bool test_matrix_inv(void)
{
const uint32_t N = 64;
__attribute__((aligned(256))) cf_t x[N * N];
__attribute__((aligned(256))) cf_t y[N * N];
srslte_matrix_NxN_inv_t matrix_nxn_inv = {};
srslte_matrix_NxN_inv_init(&matrix_nxn_inv, N);
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
x[i * N + j] = srslte_random_uniform_complex_dist(random_gen, -1.0f, +1.0f);
}
}
srslte_matrix_NxN_inv_run(&matrix_nxn_inv, x, y);
srslte_matrix_NxN_inv_free(&matrix_nxn_inv);
return true;
}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {
bool passed = true; bool passed = true;
@ -292,6 +318,10 @@ int main(int argc, char** argv)
#endif /* SRSLTE_SIMD_CF_SIZE != 0*/ #endif /* SRSLTE_SIMD_CF_SIZE != 0*/
} }
if (inverter) {
RUN_TEST(test_matrix_inv);
}
RUN_TEST(test_vec_dot_prod_ccc); RUN_TEST(test_vec_dot_prod_ccc);
printf("%s!\n", (passed) ? "Ok" : "Failed"); printf("%s!\n", (passed) ? "Ok" : "Failed");

Loading…
Cancel
Save