Fix LDPC compilation for GCC 10

master
Xavier Arteaga 4 years ago committed by Xavier Arteaga
parent 576a923a4c
commit 93752fb2c4

@ -54,8 +54,8 @@ static const int8_t infinity7 = (1U << 6U) - 1;
* \brief Represents a node of the base factor graph.
*/
typedef union bg_node_avx512_t {
int8_t c[SRSLTE_AVX512_B_SIZE]; /*!< Each base node may contain up to \ref SRSLTE_AVX512_B_SIZE lifted nodes. */
__m512i v; /*!< All the lifted nodes of the current base node as a 512-bit line. */
int8_t* c; /*!< Each base node may contain up to \ref SRSLTE_AVX512_B_SIZE lifted nodes. */
__m512i* v; /*!< All the lifted nodes of the current base node as a 512-bit line. */
} bg_node_avx512_t;
/*!
@ -64,7 +64,7 @@ typedef union bg_node_avx512_t {
struct ldpc_regs_c_avx512 {
__m512i scaling_fctr; /*!< \brief Scaling factor for the normalized min-sum decoding algorithm. */
bg_node_avx512_t* soft_bits; /*!< \brief A-posteriori log-likelihood ratios. */
bg_node_avx512_t soft_bits; /*!< \brief A-posteriori log-likelihood ratios. */
__m512i* check_to_var; /*!< \brief Check-to-variable messages. */
__m512i* var_to_check; /*!< \brief Variable-to-check messages. */
__m512i* var_to_check_to_free; /*!< \brief the Variable-to-check messages with one extra _mm512 allocated space. */
@ -123,43 +123,34 @@ void* create_ldpc_dec_c_avx512(uint8_t bgN, uint8_t bgM, uint16_t ls, float scal
uint8_t bgK = bgN - bgM;
uint16_t hrr = bgK + 4;
if ((vp = srslte_vec_malloc(sizeof(struct ldpc_regs_c_avx512))) == NULL) {
if ((vp = SRSLTE_MEM_ALLOC(struct ldpc_regs_c_avx512, 1)) == NULL) {
return NULL;
}
SRSLTE_MEM_ZERO(vp, struct ldpc_regs_c_avx512, 1);
if ((vp->soft_bits = srslte_vec_malloc(bgN * sizeof(bg_node_avx512_t))) == NULL) {
free(vp);
if ((vp->soft_bits.v = SRSLTE_MEM_ALLOC(__m512i, bgN)) == NULL) {
delete_ldpc_dec_c_avx512(vp);
return NULL;
}
if ((vp->check_to_var = srslte_vec_malloc((hrr + 1) * bgM * sizeof(__m512i))) == NULL) {
free(vp->soft_bits);
free(vp);
if ((vp->check_to_var = SRSLTE_MEM_ALLOC(__m512i, (hrr + 1) * bgM)) == NULL) {
delete_ldpc_dec_c_avx512(vp);
return NULL;
}
if ((vp->var_to_check_to_free = srslte_vec_malloc(((hrr + 1) + 2) * sizeof(__m512i))) == NULL) {
free(vp->check_to_var);
free(vp->soft_bits);
free(vp);
if ((vp->var_to_check_to_free = SRSLTE_MEM_ALLOC(__m512i, (hrr + 1) + 2)) == NULL) {
delete_ldpc_dec_c_avx512(vp);
return NULL;
}
vp->var_to_check = &vp->var_to_check_to_free[1];
if ((vp->rotated_v2c = srslte_vec_malloc((hrr + 1) * sizeof(__m512i))) == NULL) {
free(vp->var_to_check_to_free);
free(vp->check_to_var);
free(vp->soft_bits);
free(vp);
if ((vp->rotated_v2c = SRSLTE_MEM_ALLOC(__m512i, hrr + 1)) == NULL) {
delete_ldpc_dec_c_avx512(vp);
return NULL;
}
if ((vp->this_c2v_epi8_to_free = srslte_vec_malloc((1 + 2) * sizeof(__m512i))) == NULL) {
free(vp->rotated_v2c);
free(vp->var_to_check_to_free);
free(vp->check_to_var);
free(vp->soft_bits);
free(vp);
if ((vp->this_c2v_epi8_to_free = SRSLTE_MEM_ALLOC(__m512i, 1 + 2)) == NULL) {
delete_ldpc_dec_c_avx512(vp);
return NULL;
}
vp->this_c2v_epi8 =
@ -181,14 +172,25 @@ void delete_ldpc_dec_c_avx512(void* p)
{
struct ldpc_regs_c_avx512* vp = p;
if (vp != NULL) {
if (vp == NULL) {
return;
}
if (vp->this_c2v_epi8_to_free) {
free(vp->this_c2v_epi8_to_free);
}
if (vp->rotated_v2c) {
free(vp->rotated_v2c);
}
if (vp->var_to_check_to_free) {
free(vp->var_to_check_to_free);
}
if (vp->check_to_var) {
free(vp->check_to_var);
free(vp->soft_bits);
free(vp);
}
if (vp->soft_bits.v) {
free(vp->soft_bits.v);
}
free(vp);
}
int init_ldpc_dec_c_avx512(void* p, const int8_t* llrs, uint16_t ls)
@ -204,19 +206,19 @@ int init_ldpc_dec_c_avx512(void* p, const int8_t* llrs, uint16_t ls)
// First 2 punctured bits
int ini = SRSLTE_AVX512_B_SIZE + SRSLTE_AVX512_B_SIZE;
bzero(vp->soft_bits->c, ini);
srslte_vec_i8_zero(vp->soft_bits.c, ini);
for (i = 0; i < vp->finalN; i = i + ls) {
for (k = 0; k < ls; k++) {
vp->soft_bits->c[ini + k] = llrs[i + k];
vp->soft_bits.c[ini + k] = llrs[i + k];
}
// this might be removed
bzero(&vp->soft_bits->c[ini + ls], (SRSLTE_AVX512_B_SIZE - ls) * sizeof(int8_t));
srslte_vec_i8_zero(&vp->soft_bits.c[ini + ls], SRSLTE_AVX512_B_SIZE - ls);
ini = ini + SRSLTE_AVX512_B_SIZE;
}
bzero(vp->check_to_var, (vp->hrr + 1) * vp->bgM * sizeof(__m512i));
bzero(vp->var_to_check, (vp->hrr + 1) * sizeof(__m512i));
SRSLTE_MEM_ZERO(vp->check_to_var, __m512i, (vp->hrr + 1) * vp->bgM);
SRSLTE_MEM_ZERO(vp->var_to_check, __m512i, vp->hrr + 1);
return 0;
}
@ -231,7 +233,7 @@ int extract_ldpc_message_c_avx512(void* p, uint8_t* message, uint16_t liftK)
int ini = 0;
for (int i = 0; i < liftK; i = i + vp->ls) {
for (int k = 0; k < vp->ls; k++) {
message[i + k] = (vp->soft_bits->c[ini + k] < 0);
message[i + k] = (vp->soft_bits.c[ini + k] < 0);
}
ini = ini + SRSLTE_AVX512_B_SIZE;
}
@ -250,15 +252,12 @@ int update_ldpc_var_to_check_c_avx512(void* p, int i_layer)
__m512i* this_check_to_var = vp->check_to_var + i_layer * (vp->hrr + 1);
// Update the high-rate region.
inner_var_to_check_c_avx512(&(vp->soft_bits[0].v), this_check_to_var, vp->var_to_check, infinity7, vp->hrr);
inner_var_to_check_c_avx512(vp->soft_bits.v, this_check_to_var, vp->var_to_check, infinity7, vp->hrr);
if (i_layer >= 4) {
// Update the extension region.
inner_var_to_check_c_avx512(&(vp->soft_bits[0].v) + vp->hrr + i_layer - 4,
this_check_to_var + vp->hrr,
vp->var_to_check + vp->hrr,
infinity7,
1);
inner_var_to_check_c_avx512(
vp->soft_bits.v + vp->hrr + i_layer - 4, this_check_to_var + vp->hrr, vp->var_to_check + vp->hrr, infinity7, 1);
}
return 0;
@ -382,7 +381,7 @@ int update_ldpc_soft_bits_c_avx512(void* p, int i_layer, const int8_t (*these_va
mask_epi8 = _mm512_cmpgt_epi8_mask(_mm512_neg_infty7_epi8, tmp_epi8);
vp->soft_bits[current_var_index].v = _mm512_mask_blend_epi8(mask_epi8, tmp_epi8, _mm512_neg_infty8_epi8);
vp->soft_bits.v[current_var_index] = _mm512_mask_blend_epi8(mask_epi8, tmp_epi8, _mm512_neg_infty8_epi8);
current_var_index = (*these_var_indices)[i + 1];
}

@ -39,15 +39,15 @@
* \brief Represents a node of the base factor graph.
*/
typedef union bg_node_t {
uint8_t c[SRSLTE_AVX2_B_SIZE]; /*!< Each base node may contain up to \ref SRSLTE_AVX2_B_SIZE lifted nodes. */
__m256i v; /*!< All the lifted nodes of the current base node as a 256-bit line. */
uint8_t* c; /*!< Each base node may contain up to \ref SRSLTE_AVX2_B_SIZE lifted nodes. */
__m256i* v; /*!< All the lifted nodes of the current base node as a 256-bit line. */
} bg_node_t;
/*!
* \brief Inner registers for the optimized LDPC encoder.
*/
struct ldpc_enc_avx2 {
bg_node_t* codeword; /*!< \brief Contains the entire codeword, before puncturing. */
bg_node_t codeword; /*!< \brief Contains the entire codeword, before puncturing. */
__m256i* aux; /*!< \brief Auxiliary register. */
};
@ -95,18 +95,17 @@ void* create_ldpc_enc_avx2(srslte_ldpc_encoder_t* q)
{
struct ldpc_enc_avx2* vp = NULL;
if ((vp = malloc(sizeof(struct ldpc_enc_avx2))) == NULL) {
if ((vp = SRSLTE_MEM_ALLOC(struct ldpc_enc_avx2, 1)) == NULL) {
return NULL;
}
if ((vp->codeword = srslte_vec_malloc(q->bgN * sizeof(bg_node_t))) == NULL) {
free(vp);
if ((vp->codeword.v = SRSLTE_MEM_ALLOC(__m256i, q->bgN)) == NULL) {
delete_ldpc_enc_avx2(vp);
return NULL;
}
if ((vp->aux = srslte_vec_malloc(q->bgM * sizeof(__m256i))) == NULL) {
free(vp->codeword);
free(vp);
if ((vp->aux = SRSLTE_MEM_ALLOC(__m256i, q->bgM)) == NULL) {
delete_ldpc_enc_avx2(vp);
return NULL;
}
@ -117,11 +116,16 @@ void delete_ldpc_enc_avx2(void* p)
{
struct ldpc_enc_avx2* vp = p;
if (vp != NULL) {
if (vp == NULL) {
return;
}
if (vp->aux) {
free(vp->aux);
free(vp->codeword);
free(vp);
}
if (vp->codeword.v) {
free(vp->codeword.v);
}
free(vp);
}
int load_avx2(void* p, const uint8_t* input, const uint8_t msg_len, const uint8_t cdwd_len, const uint16_t ls)
@ -136,14 +140,14 @@ int load_avx2(void* p, const uint8_t* input, const uint8_t msg_len, const uint8_
int node_size = SRSLTE_AVX2_B_SIZE;
for (int i = 0; i < msg_len * ls; i = i + ls) {
for (int k = 0; k < ls; k++) {
vp->codeword->c[ini + k] = input[i + k];
vp->codeword.c[ini + k] = input[i + k];
}
// this zero padding can be removed
bzero(&(vp->codeword->c[ini + ls]), (node_size - ls) * sizeof(uint8_t));
srslte_vec_u8_zero(&vp->codeword.c[ini + ls], node_size - ls);
ini = ini + node_size;
}
bzero(vp->codeword + msg_len, (cdwd_len - msg_len) * sizeof(__m256i));
SRSLTE_MEM_ZERO(vp->codeword.v + msg_len, __m256i, cdwd_len - msg_len);
return 0;
}
@ -159,7 +163,7 @@ int return_codeword_avx2(void* p, uint8_t* output, const uint8_t cdwd_len, const
int ini = SRSLTE_AVX2_B_SIZE + SRSLTE_AVX2_B_SIZE;
for (int i = 0; i < (cdwd_len - 2) * ls; i = i + ls) {
for (int k = 0; k < ls; k++) {
output[i + k] = vp->codeword->c[ini + k];
output[i + k] = vp->codeword.c[ini + k];
}
ini = ini + SRSLTE_AVX2_B_SIZE;
}
@ -184,14 +188,14 @@ void encode_ext_region_avx2(srslte_ldpc_encoder_t* q, uint8_t n_layers)
skip = q->bgK + m;
// the systematic part has already been computed
vp->codeword[skip].v = vp->aux[m];
vp->codeword.v[skip] = vp->aux[m];
// sum the contribution due to the high-rate region, with the proper circular shifts
for (k = 0; k < 4; k++) {
this_shift = q->pcm + q->bgK + k + m * q->bgN;
if (*this_shift != NO_CNCT) {
tmp_epi8 = rotate_node_right(vp->codeword[q->bgK + k].v, *this_shift, q->ls);
vp->codeword[skip].v = _mm256_xor_si256(vp->codeword[skip].v, tmp_epi8);
tmp_epi8 = rotate_node_right(vp->codeword.v[q->bgK + k], *this_shift, q->ls);
vp->codeword.v[skip] = _mm256_xor_si256(vp->codeword.v[skip], tmp_epi8);
}
}
}
@ -228,7 +232,7 @@ void preprocess_systematic_bits_avx2(srslte_ldpc_encoder_t* q)
// xor array aux[m] with a circularly shifted version of the current input chunk, unless
// the current check node and variable node are not connected.
if (*this_shift != NO_CNCT) {
tmp_epi8 = rotate_node_right(vp->codeword[k].v, *this_shift, ls);
tmp_epi8 = rotate_node_right(vp->codeword.v[k], *this_shift, ls);
tmp_epi8 = _mm256_and_si256(tmp_epi8, one_epi8);
vp->aux[m] = _mm256_xor_si256(vp->aux[m], tmp_epi8);
}
@ -249,17 +253,17 @@ void encode_high_rate_case1_avx2(void* o)
int skip3 = q->bgK + 3;
// first chunk of parity bits
vp->codeword[skip0].v = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
vp->codeword[skip0].v = _mm256_xor_si256(vp->codeword[skip0].v, vp->aux[2]);
vp->codeword[skip0].v = _mm256_xor_si256(vp->codeword[skip0].v, vp->aux[3]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->codeword.v[skip0], vp->aux[2]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->codeword.v[skip0], vp->aux[3]);
__m256i tmp_epi8 = rotate_node_right(vp->codeword[skip0].v, 1, ls);
__m256i tmp_epi8 = rotate_node_right(vp->codeword.v[skip0], 1, ls);
// second chunk of parity bits
vp->codeword[skip1].v = _mm256_xor_si256(vp->aux[0], tmp_epi8);
vp->codeword.v[skip1] = _mm256_xor_si256(vp->aux[0], tmp_epi8);
// fourth chunk of parity bits
vp->codeword[skip3].v = _mm256_xor_si256(vp->aux[3], tmp_epi8);
vp->codeword.v[skip3] = _mm256_xor_si256(vp->aux[3], tmp_epi8);
// third chunk of parity bits
vp->codeword[skip2].v = _mm256_xor_si256(vp->aux[2], vp->codeword[skip3].v);
vp->codeword.v[skip2] = _mm256_xor_si256(vp->aux[2], vp->codeword.v[skip3]);
}
void encode_high_rate_case2_avx2(void* o)
@ -278,14 +282,14 @@ void encode_high_rate_case2_avx2(void* o)
__m256i tmp_epi8 = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
tmp_epi8 = _mm256_xor_si256(tmp_epi8, vp->aux[2]);
tmp_epi8 = _mm256_xor_si256(tmp_epi8, vp->aux[3]);
vp->codeword[skip0].v = rotate_node_left(tmp_epi8, 105 % ls, ls);
vp->codeword.v[skip0] = rotate_node_left(tmp_epi8, 105 % ls, ls);
// second chunk of parity bits
vp->codeword[skip1].v = _mm256_xor_si256(vp->aux[0], vp->codeword[skip0].v);
vp->codeword.v[skip1] = _mm256_xor_si256(vp->aux[0], vp->codeword.v[skip0]);
// fourth chunk of parity bits
vp->codeword[skip3].v = _mm256_xor_si256(vp->aux[3], vp->codeword[skip0].v);
vp->codeword.v[skip3] = _mm256_xor_si256(vp->aux[3], vp->codeword.v[skip0]);
// third chunk of parity bits
vp->codeword[skip2].v = _mm256_xor_si256(vp->aux[2], vp->codeword[skip3].v);
vp->codeword.v[skip2] = _mm256_xor_si256(vp->aux[2], vp->codeword.v[skip3]);
}
void encode_high_rate_case3_avx2(void* o)
@ -304,14 +308,14 @@ void encode_high_rate_case3_avx2(void* o)
__m256i tmp_epi8 = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
tmp_epi8 = _mm256_xor_si256(tmp_epi8, vp->aux[2]);
tmp_epi8 = _mm256_xor_si256(tmp_epi8, vp->aux[3]);
vp->codeword[skip0].v = rotate_node_left(tmp_epi8, 1, ls);
vp->codeword.v[skip0] = rotate_node_left(tmp_epi8, 1, ls);
// second chunk of parity bits
vp->codeword[skip1].v = _mm256_xor_si256(vp->aux[0], vp->codeword[skip0].v);
vp->codeword.v[skip1] = _mm256_xor_si256(vp->aux[0], vp->codeword.v[skip0]);
// third chunk of parity bits
vp->codeword[skip2].v = _mm256_xor_si256(vp->aux[1], vp->codeword[skip1].v);
vp->codeword.v[skip2] = _mm256_xor_si256(vp->aux[1], vp->codeword.v[skip1]);
// fourth chunk of parity bits
vp->codeword[skip3].v = _mm256_xor_si256(vp->aux[3], vp->codeword[skip0].v);
vp->codeword.v[skip3] = _mm256_xor_si256(vp->aux[3], vp->codeword.v[skip0]);
}
void encode_high_rate_case4_avx2(void* o)
@ -327,17 +331,17 @@ void encode_high_rate_case4_avx2(void* o)
int skip3 = q->bgK + 3;
// first chunk of parity bits
vp->codeword[skip0].v = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
vp->codeword[skip0].v = _mm256_xor_si256(vp->codeword[skip0].v, vp->aux[2]);
vp->codeword[skip0].v = _mm256_xor_si256(vp->codeword[skip0].v, vp->aux[3]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->aux[0], vp->aux[1]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->codeword.v[skip0], vp->aux[2]);
vp->codeword.v[skip0] = _mm256_xor_si256(vp->codeword.v[skip0], vp->aux[3]);
__m256i tmp_epi8 = rotate_node_right(vp->codeword[skip0].v, 1, ls);
__m256i tmp_epi8 = rotate_node_right(vp->codeword.v[skip0], 1, ls);
// second chunk of parity bits
vp->codeword[skip1].v = _mm256_xor_si256(vp->aux[0], tmp_epi8);
vp->codeword.v[skip1] = _mm256_xor_si256(vp->aux[0], tmp_epi8);
// third chunk of parity bits
vp->codeword[skip2].v = _mm256_xor_si256(vp->aux[1], vp->codeword[skip1].v);
vp->codeword.v[skip2] = _mm256_xor_si256(vp->aux[1], vp->codeword.v[skip1]);
// fourth chunk of parity bits
vp->codeword[skip3].v = _mm256_xor_si256(vp->aux[3], tmp_epi8);
vp->codeword.v[skip3] = _mm256_xor_si256(vp->aux[3], tmp_epi8);
}
static __m256i _mm256_rotatelli_si256(__m256i a, int imm)

Loading…
Cancel
Save