Skip to content
Matrix_Row.hpp 6.33 KiB
Newer Older
/*
 * Copyright (c) 2018 by Intinor AB. All rights reserved.
 *
 * This file is part of "libRaptorQ".
 *
 * libRaptorQ is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation, either version 3
 * of the License, or (at your option) any later version.
 *
 * libRaptorQ is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * and a copy of the GNU Lesser General Public License
 * along with libRaptorQ.  If not, see <http://www.gnu.org/licenses/>.
 */

#pragma once
#include "../external/delegates/include/delegate.h"
#include "RaptorQ/v1/util/CPU_Info.hpp"
#include <cstdlib>
#include <stdlib.h>

namespace RaptorQ__v1 {
namespace Impl {

// Function-pointer delegates for the three row kernels Matrix_Row dispatches
// to (portable or SIMD):
//   add_delegate     -> (dest, src, bytes)
//   div_delegate     -> (data, scalar, bytes)
//   add_mul_delegate -> (dest, src, scalar, bytes)
using add_delegate = dlgt::delegate<void(*)(uint8_t *, uint8_t *, int)>;
using div_delegate = dlgt::delegate<void(*)(uint8_t *, uint8_t, int)>;
using add_mul_delegate =
            dlgt::delegate<void(*)(uint8_t *, uint8_t *, uint8_t, int)>;

namespace Matrix_Row_SIMD {
    // Vectorised row kernels; only declared here, defined outside this header.
    // The SIMD versions compute: dest[i] ^= (high[src[i]>>4] ^ low[src[i]&0xf]);
    // where high and low are 16-byte lookup tables for the high and the low
    // four bits of each source byte.
    //
    // They are free functions instead of Matrix_Row members because the
    // gcc 4.9.4 linker with LTO got into real trouble otherwise.

    // --- AVX2 ---
    void add_avx2(uint8_t *dest, uint8_t *src, int bytes);
    void div_avx2(uint8_t *data, uint8_t num, int bytes);
    void multiply_and_add_avx2(uint8_t *data, uint8_t *src, uint8_t num,
                                                                    int bytes);

    // --- SSSE3 ---
    void add_ssse3(uint8_t *dest, uint8_t *src, int bytes);
    void div_ssse3(uint8_t *data, uint8_t num, int bytes);
    void multiply_and_add_ssse3(uint8_t *data, uint8_t *src, uint8_t num,
                                                                    int bytes);
} // namespace Matrix_Row_SIMD

class RAPTORQ_LOCAL Matrix_Row
{
public:
    enum class SIMD : uint8_t {
        NONE = 0x00,
        AUTO = 0x01,
        SSSE3 = 0x02,
        AVX2 = 0x03,
    };

    static Matrix_Row & get_instance() {
        static Matrix_Row matrix;
        return matrix;
    }

    static void init_simd(SIMD type) {
        Matrix_Row::get_instance()._init_simd(type);
    }

    static void row_multiply_add(uint8_t *dest, uint8_t *src, uint8_t scalar,
                                                                    int bytes)
    {
        if (scalar == 1) {
            Matrix_Row::get_instance().add_del(dest, src, bytes);
            Matrix_Row::get_instance().add_mul_del(dest, src, scalar, bytes);
        }
    }

    static void row_div(uint8_t *data, uint8_t scalar, int bytes)
    {
        if (scalar == 0) {
            return;
        }
        Matrix_Row::get_instance().div_del(data, scalar, bytes);
    }

    static void row_add(uint8_t *dest, uint8_t *src, int bytes)
    {
        Matrix_Row::get_instance().add_del(dest, src, bytes);
    add_delegate add_del = dlgt::make_delegate(&add);
    div_delegate div_del = dlgt::make_delegate(&div);
    add_mul_delegate add_mul_del = dlgt::make_delegate(&multiply_and_add);

    Matrix_Row() {
        _init_simd(SIMD::AUTO);
    }

    // Ordered from most wanted to least wanted
    SIMD get_simd_type()
    {
        if (CPU_Info::has_avx2()) {
            return SIMD::AVX2;
        }
        if (CPU_Info::has_ssse3()) {
            return SIMD::SSSE3;
        }
        return SIMD::NONE;
    }

    bool has_simd_type(SIMD type)
    {
        bool return_value = false; //Better safe than sorry
        switch( type )
        {
        case SIMD::SSSE3:
            return_value = CPU_Info::has_ssse3();
            break;
        case SIMD::AVX2:
            return_value = CPU_Info::has_avx2();
            break;
        default:
            return_value = true;
            break;
        }
        return return_value;
    }

    void _init_simd(SIMD type) {
        if (type == SIMD::AUTO || !has_simd_type(type)) {
            type = get_simd_type();
        }

        switch( type )
        {
        case SIMD::SSSE3:
            add_del = dlgt::make_delegate(&Matrix_Row_SIMD::add_ssse3);
            div_del = dlgt::make_delegate(&Matrix_Row_SIMD::div_ssse3);
            add_mul_del = dlgt::make_delegate(
                                    &Matrix_Row_SIMD::multiply_and_add_ssse3);
            break;

        case SIMD::AVX2:
            add_del = dlgt::make_delegate(&Matrix_Row_SIMD::add_avx2);
            div_del = dlgt::make_delegate(&Matrix_Row_SIMD::div_avx2);
            add_mul_del = dlgt::make_delegate(
                                    &Matrix_Row_SIMD::multiply_and_add_avx2);
            add_del = dlgt::make_delegate(&add);
            div_del = dlgt::make_delegate(&div);
            add_mul_del = dlgt::make_delegate(&multiply_and_add);
            break;
        }
    }

    static void add(uint8_t *dest, uint8_t *src, int bytes)
    {
        for (int i = 0; i < bytes;i++) {
            dest[i] = dest[i] ^ src[i];
        }
    }

    static void div(uint8_t *data, uint8_t num, int bytes)
    {
        num = oct_log[num - 1];
        for (int i = 0; i < bytes;i++) {
            if (data[i] != 0) {
                data[i] = oct_exp[oct_log[data[i] - 1] - num + 255];
            }
        }
    }

    static void multiply_and_add(uint8_t *dest, uint8_t *src, uint8_t num,
                                                                    int bytes)
    {
        if (num == 0) {
            return;
        }
        // TODO: Probably faster with a single lookup based on num, a single
        // lookup require an additional lookup table of size 256 * 256.
        uint16_t log_num = oct_log_no_if[num];
        for (int i = 0; i < bytes;i++) {
            dest[i] = dest[i] ^ oct_exp_no_if[oct_log_no_if[src[i]] + log_num];
        }
    }
};
} // namespace Impl
} // namespace RaptorQ__v1