diff --git a/matlab/tests/diversity_decode_test.m b/matlab/tests/diversity_decode_test.m index 536feebdc..75c29c2e3 100644 --- a/matlab/tests/diversity_decode_test.m +++ b/matlab/tests/diversity_decode_test.m @@ -1,8 +1,9 @@ clear -addpath('../../build/srslte/lib/mimo/test') +addpath('../../debug/srslte/lib/mimo/test') -enb = lteRMCDL('R.10'); +%enb = lteRMCDL('R.10'); % 2-ports +enb = lteRMCDL('R.0'); % 1-ports cec = struct('FreqWindow',9,'TimeWindow',9,'InterpType','cubic'); cec.PilotAverage = 'UserDefined'; @@ -10,7 +11,7 @@ cec.InterpWinSize = 1; cec.InterpWindow = 'Causal'; cfg.Seed = 1; % Random channel seed -cfg.NRxAnts = 2; % 1 receive antenna +cfg.NRxAnts = 1; % 1 receive antenna cfg.DelayProfile = 'ETU'; % EVA delay spread cfg.DopplerFreq = 100; % 120Hz Doppler frequency cfg.MIMOCorrelation = 'Low'; % Low (no) MIMO correlation @@ -22,10 +23,9 @@ cfg.NormalizePathGains = 'On'; % Normalize delay profile power cfg.NormalizeTxAnts = 'On'; % Normalize for transmit antennas [txWaveform, ~, info] = lteRMCDLTool(enb,[1;0;0;1]); -n = length(txWaveform); cfg.SamplingRate = info.SamplingRate; -txWaveform = txWaveform+complex(randn(n,2),randn(n,2))*1e-3; +txWaveform = txWaveform+complex(randn(size(txWaveform)),randn(size(txWaveform)))*1e-3; rxWaveform = lteFadingChannel(cfg,txWaveform); @@ -35,16 +35,36 @@ rxGrid = lteOFDMDemodulate(enb,rxWaveform); s=size(h); p=s(1); -Nt=s(4); -Nr=s(3); +n=s(2); +if (length(s)>2) + Nr=s(3); +else + Nr=1; +end +if (length(s)>3) + Nt=s(4); +else + Nt=1; +end -rx=reshape(rxGrid(:,1,:),p,Nr); -hp=reshape(h(:,1,:,:),p,Nr,Nt); +if (Nr > 1) + rx=reshape(rxGrid,p,n,Nr); + hp=reshape(h,p,n,Nr,Nt); +else + rx=rxGrid; + hp=h; +end -output_mat = lteTransmitDiversityDecode(rx, hp); -output_srs = srslte_diversitydecode(rx, hp); +if (Nt > 1) + output_mat = lteTransmitDiversityDecode(rx, hp); +else + output_mat = lteEqualizeMMSE(rx, hp, n0); +end +output_srs = srslte_diversitydecode(rx, hp, n0); -plot(abs(output_mat-output_srs)) -mean(abs(output_mat-output_srs).^2) +plot(abs(output_mat(:)-output_srs(:))) +mean(abs(output_mat(:)-output_srs(:)).^2) +t=1:10; +plot(t,real(output_mat(t)),t,real(output_srs(t))) diff --git a/srslte/lib/mimo/precoding.c b/srslte/lib/mimo/precoding.c index 3b1d9a649..78cdaaecc 100644 --- a/srslte/lib/mimo/precoding.c +++ b/srslte/lib/mimo/precoding.c @@ -71,9 +71,9 @@ int srslte_predecoding_single_sse(cf_t *y[SRSLTE_MAX_RXANT], cf_t *h[SRSLTE_MAX_ __m128 noise = _mm_set1_ps(noise_estimate); __m128 h1Val1, h2Val1, y1Val1, y2Val1; __m128 h1Val2, h2Val2, y1Val2, y2Val2; - __m128 h12square1, h1square1, h2square1, h1conj1, h2conj1, x1Val1, x2Val1; - __m128 h12square2, h1square2, h2square2, h1conj2, h2conj2, x1Val2, x2Val2; - + __m128 hsquare, h1square, h2square, h1conj1, h2conj1, x1Val1, x2Val1; + __m128 hsquare2, h1conj2, h2conj2, x1Val2, x2Val2; + for (int i=0;i 0) { - h12square1 = _mm_add_ps(h12square1, noise); + hsquare = _mm_add_ps(hsquare, noise); } - h1square1 = _mm_shuffle_ps(h12square1, h12square1, _MM_SHUFFLE(1, 1, 0, 0)); - h2square1 = _mm_shuffle_ps(h12square1, h12square1, _MM_SHUFFLE(3, 3, 2, 2)); - - if (nof_rxant == 2) { - h1square2 = _mm_shuffle_ps(h12square2, h12square2, _MM_SHUFFLE(1, 1, 0, 0)); - h2square2 = _mm_shuffle_ps(h12square2, h12square2, _MM_SHUFFLE(3, 3, 2, 2)); - - h1square1 = _mm_add_ps(h1square1, h1square2); - h2square1 = _mm_add_ps(h2square1, h2square2); - } + h1square = _mm_shuffle_ps(hsquare, hsquare, _MM_SHUFFLE(1, 1, 0, 0)); + h2square = _mm_shuffle_ps(hsquare, hsquare, _MM_SHUFFLE(3, 3, 2, 2)); /* Conjugate channel */ h1conj1 = _mm_xor_ps(h1Val1, conjugator); @@ -122,17 +114,26 @@ int srslte_predecoding_single_sse(cf_t *y[SRSLTE_MAX_RXANT], cf_t *h[SRSLTE_MAX_ if (nof_rxant == 2) { x1Val2 = PROD(y1Val2, h1conj2); - x2Val2 = PROD(y2Val2, h2conj2); + x2Val2 = PROD(y2Val2, h2conj2); + x1Val1 = _mm_add_ps(x1Val1, x1Val2); + x2Val1 = _mm_add_ps(x2Val1, x2Val2); } - x1Val1 = _mm_div_ps(x1Val1, h1square1); - x2Val1 = _mm_div_ps(x2Val1, h2square1); + x1Val1 = _mm_div_ps(x1Val1, h1square); + x2Val1 = _mm_div_ps(x2Val1, h2square); + + _mm_store_ps(xPtr, x1Val1); xPtr+=4; + _mm_store_ps(xPtr, x2Val1); xPtr+=4; - _mm_store_ps(xPtr, x1Val); xPtr+=4; - _mm_store_ps(xPtr, x2Val); xPtr+=4; } for (int i=8*(nof_symbols/8);i 32) { - return srslte_predecoding_single_avx(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_avx(y, h, x, nof_rxant, nof_symbols, noise_estimate); } else { - return srslte_predecoding_single_gen(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); } #else #ifdef LV_HAVE_SSE if (nof_symbols > 32) { - return srslte_predecoding_single_sse(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_sse(y, h, x, nof_rxant, nof_symbols, noise_estimate); } else { - return srslte_predecoding_single_gen(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); } #else return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); @@ -240,16 +249,16 @@ int srslte_predecoding_single(cf_t *y_, cf_t *h_, cf_t *x, int nof_symbols, floa int srslte_predecoding_single_multi(cf_t *y[SRSLTE_MAX_RXANT], cf_t *h[SRSLTE_MAX_RXANT], cf_t *x, int nof_rxant, int nof_symbols, float noise_estimate) { #ifdef LV_HAVE_AVX if (nof_symbols > 32) { - return srslte_predecoding_single_avx(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_avx(y, h, x, nof_rxant, nof_symbols, noise_estimate); } else { - return srslte_predecoding_single_gen(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); } #else #ifdef LV_HAVE_SSE if (nof_symbols > 32) { - return srslte_predecoding_single_sse(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_sse(y, h, x, nof_rxant, nof_symbols, noise_estimate); } else { - return srslte_predecoding_single_gen(y, h, x, nof_symbols, noise_estimate); + return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); } #else return srslte_predecoding_single_gen(y, h, x, nof_rxant, nof_symbols, noise_estimate); diff --git a/srslte/lib/mimo/test/diversitydecode_mex.c b/srslte/lib/mimo/test/diversitydecode_mex.c index bbc99416d..7c020c6b8 100644 --- a/srslte/lib/mimo/test/diversitydecode_mex.c +++ b/srslte/lib/mimo/test/diversitydecode_mex.c @@ -33,6 +33,7 @@ #define INPUT prhs[0] #define HEST prhs[1] +#define NEST prhs[2] #define NOF_INPUTS 2 @@ -56,44 +57,70 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) } // Read input symbols - nof_symbols = mexutils_read_cf(INPUT, &input); - if (nof_symbols < 0) { + if (mexutils_read_cf(INPUT, &input) < 0) { mexErrMsgTxt("Error reading input\n"); return; } + uint32_t nof_tx_ports = 1; + uint32_t nof_rx_ants = 1; + const mwSize *dims = mxGetDimensions(INPUT); + mwSize ndims = mxGetNumberOfDimensions(INPUT); + nof_symbols = dims[0]*dims[1]; + + if (ndims >= 3) { + nof_rx_ants = dims[2]; + } + if (ndims >= 4) { + nof_tx_ports = dims[3]; + } + // Read channel estimates - uint32_t nof_symbols2 = mexutils_read_cf(HEST, &hest); - if (nof_symbols < 0) { + if (mexutils_read_cf(HEST, &hest) < 0) { mexErrMsgTxt("Error reading hest\n"); return; } - if ((nof_symbols2 % nof_symbols) != 0) { - mexErrMsgTxt("Hest size must be multiple of input size\n"); - return; + + // Read noise estimate + float noise_estimate = 0; + if (nrhs >= NOF_INPUTS) { + noise_estimate = mxGetScalar(NEST); } - // Calculate number of ports - uint32_t nof_ports = nof_symbols2/nof_symbols; - cf_t *x[8]; - cf_t *h[4]; + cf_t *x[SRSLTE_MAX_LAYERS]; + cf_t *h[SRSLTE_MAX_PORTS][SRSLTE_MAX_RXANT]; + cf_t *y[SRSLTE_MAX_RXANT]; + + for (int i=0;i 1) { + //srslte_predecoding_diversity(input, h, x, nof_tx_ports, nof_symbols); + //srslte_layerdemap_diversity(x, output, nof_tx_ports, nof_symbols / nof_tx_ports); + } else { + srslte_predecoding_single_multi(y, h[0], output, nof_rx_ants, nof_symbols, noise_estimate); } - srslte_predecoding_diversity(input, h, x, nof_ports, nof_symbols); - srslte_layerdemap_diversity(x, output, nof_ports, nof_symbols / nof_ports); - if (nlhs >= 1) { mexutils_write_cf(output, &plhs[0], nof_symbols, 1); @@ -105,7 +132,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) if (output) { free(output); } - for (i=0;i<8;i++) { + for (int i=0;i