library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

library desy;
use desy.math_signed.all;

use work.pkg_corr_matrix.all;

-- corr_ll: bank of C_N_MM_PSC independent corrector lines, one per element of
-- the matmult input array. Each line combines the current input sample, the
-- previously held input sample and its own previous output, weighted by the
-- coef_a / coef_b / coef_d coefficients, scales the saturated sum by coef_ic,
-- then rounds and resizes the result to C_W_COR bits.
-- Lines with index below 51 use the *_x coefficients and controls, the other
-- lines use the *_y set.
entity corr_ll is
    port(
        -- Clock; all registers update on its rising edge.
        clk             : in std_logic;
        -- Asynchronous reset, active low.
        rst_n           : in std_logic;


        -- Coefficients for the lines using the X set:
        --   coef_a  multiplies the current matmult sample,
        --   coef_b  multiplies the previously held matmult sample,
        --   coef_ic multiplies the saturated sum of the three products,
        --   coef_d  multiplies the line's current output register.
        coef_a_x        : in signed(C_W_COR_COEF-1 downto 0);
        coef_b_x        : in signed(C_W_COR_COEF-1 downto 0);
        coef_ic_x       : in signed(C_W_COR_COEF-1 downto 0);
        coef_d_x        : in signed(C_W_COR_COEF-1 downto 0);
        -- Same four coefficients for the lines using the Y set.
        coef_a_y        : in signed(C_W_COR_COEF-1 downto 0);
        coef_b_y        : in signed(C_W_COR_COEF-1 downto 0);
        coef_ic_y       : in signed(C_W_COR_COEF-1 downto 0);
        coef_d_y        : in signed(C_W_COR_COEF-1 downto 0);

        -- enable_corr_*: when '0' the output register of the lines keeps its
        --                value instead of loading a new result.
        -- reset_corr_* : synchronous clear of the held input sample and of the
        --                output register; has priority over enable and valid.
        enable_corr_x   : in std_logic;
        reset_corr_x    : in std_logic;
        enable_corr_y   : in std_logic;
        reset_corr_y    : in std_logic;

        -- Input samples, one per line.
        matmult         : in signed_array(0 to C_N_MM_PSC-1)(C_W_MM-1 downto 0);
        -- Strobe marking the cycle on which matmult is captured.
        matmult_valid   : in std_logic;
        -- Value forwarded unchanged to corrout_seq with the same delay as the valid strobe.
        matmult_seq     : in std_logic_vector(C_W_BPMSEQ-1 downto 0);

        -- matmult_valid delayed by four clock cycles.
        corrout_valid   : out std_logic;
        -- matmult_seq delayed by four clock cycles.
        corrout_seq     : out std_logic_vector(C_W_BPMSEQ-1 downto 0);
        -- Output register of each line.
        corrout         : out signed_array(0 to C_N_MM_PSC-1)(C_W_COR-1 downto 0)
    );
end entity corr_ll;


architecture rtl of corr_ll is

    -- Depth, in clock cycles, of the valid/sequence shift registers below.
    -- The corrector lines load their output register on delay_valid(2), i.e.
    -- one cycle before the last stage is presented on corrout_valid.
    constant C_DELAY    : natural := 4;
    -- Array of std_logic_vector; the element width is fixed where an object
    -- of this type is declared.
    type arr_slv is array (natural range <>) of std_logic_vector;
    -- Shift register fed by matmult_valid; bit 0 is the most recent entry.
    signal delay_valid  : std_logic_vector(C_DELAY-1 downto 0);
    -- Shift register fed by matmult_seq; index 0 is the most recent entry.
    signal delay_seq    : arr_slv(0 to C_DELAY-1)(C_W_BPMSEQ-1 downto 0);


begin

    ----------------------
    -- OUTPUT CONNEXION --
    ----------------------
    -- Both outputs are taken from the last stage of their shift register.
    corrout_seq     <= delay_seq(delay_seq'high);
    corrout_valid   <= delay_valid(C_DELAY-1);

    --------------------
    -- DELAY REGISTER --
    --------------------
    -- Shifts matmult_valid and matmult_seq through C_DELAY register stages.
    -- Entries come in at index 0 and move one index up on every rising clock
    -- edge; everything is cleared asynchronously while rst_n is low.
    p_delay:process(clk, rst_n)
    begin
        if rst_n = '0' then
            delay_seq      <= (others => (others => '0'));
            delay_valid    <= (others => '0');
        elsif rising_edge(clk) then
            -- Valid flags: drop the oldest bit, append the incoming one at bit 0.
            delay_valid <= delay_valid(C_DELAY-2 downto 0) & matmult_valid;

            -- Sequence numbers: every stage copies its predecessor, stage 0
            -- takes the incoming value.
            for stage in delay_seq'low+1 to delay_seq'high loop
                delay_seq(stage) <= delay_seq(stage-1);
            end loop;
            delay_seq(delay_seq'low) <= matmult_seq;
        end if;
    end process p_delay;

    --------------------
    -- CORRECTOR LINE --
    --------------------
    -- One corrector line per element of matmult. Per line, on the cycles
    -- flagged by the valid pipeline:
    --   stage 1: products  a * matmult(I), b * reg_din (previously held
    --            sample), d * reg_dout (current output)
    --   stage 2: sum of the three resized products (f_sum_sat)
    --   stage 3: sum multiplied by coef_ic
    --   stage 4: rounded and resized result loaded into reg_dout
    -- NOTE(review): f_resize_sat / f_sum_sat / f_resize_lsb come from an
    -- external package; from their names they saturate on resize / sum and
    -- drop least significant bits respectively -- confirm against the package.
    gen_corr:for I in 0 to C_N_MM_PSC-1 generate

        -- Lines with an index below this value take the *_x coefficients and
        -- controls, lines at or above it take the *_y set.
        constant C_IDX_FIRST_Y  : natural := 51;

        -- Coefficients and controls selected for this line
        signal coef_a       : signed(C_W_COR_COEF-1 downto 0);
        signal coef_b       : signed(C_W_COR_COEF-1 downto 0);
        signal coef_ic      : signed(C_W_COR_COEF-1 downto 0);
        signal coef_d       : signed(C_W_COR_COEF-1 downto 0);
        signal enable_corr  : std_logic;
        signal reset_corr   : std_logic;

        -- Held input sample and output register
        signal reg_din      : signed(C_W_MM-1 downto 0);
        signal reg_dout     : signed(C_W_COR-1 downto 0);

        -- Registered products
        signal mult_a       : signed(C_W_MM+C_W_COR_COEF-1 downto 0);
        signal mult_b       : signed(C_W_MM+C_W_COR_COEF-1 downto 0);
        signal mult_d       : signed(C_W_COR+C_W_COR_COEF-1 downto 0);
        signal mult_ic      : signed(C_W_COR_SUMSAT+C_W_COR_COEF-1 downto 0);

        -- Products brought to the sum width, and the registered sum
        signal rz_mult_a    : signed(C_W_COR_SUMSAT-1 downto 0);
        signal rz_mult_b    : signed(C_W_COR_SUMSAT-1 downto 0);
        signal rz_mult_d    : signed(C_W_COR_SUMSAT-1 downto 0);
        signal sum_abd      : signed(C_W_COR_SUMSAT-1 downto 0);

        -- mult_ic after f_resize_lsb, before and after the rounding increment
        signal trunc_abicd  : signed(C_W_COR_SUMSAT+C_W_COR_COEF-C_N_COR_RND-1 downto 0);
        signal rnd_abicd    : signed(C_W_COR_SUMSAT+C_W_COR_COEF-C_N_COR_RND-1 downto 0);

    begin

        -- Static selection of the X or Y coefficient/control set, based on
        -- the line index (PSCID number). Exactly one of the two generate
        -- bodies exists for a given line.
        gen_x_sigmux:if I < C_IDX_FIRST_Y generate
            enable_corr    <= enable_corr_x;
            reset_corr     <= reset_corr_x;
            coef_a         <= coef_a_x;
            coef_b         <= coef_b_x;
            coef_ic        <= coef_ic_x;
            coef_d         <= coef_d_x;
        end generate gen_x_sigmux;

        gen_y_sigmux:if I >= C_IDX_FIRST_Y generate
            enable_corr    <= enable_corr_y;
            reset_corr     <= reset_corr_y;
            coef_a         <= coef_a_y;
            coef_b         <= coef_b_y;
            coef_ic        <= coef_ic_y;
            coef_d         <= coef_d_y;
        end generate gen_y_sigmux;

        p_main:process(clk, rst_n)
        begin
            if rst_n = '0' then
                reg_din         <= (others => '0');
                reg_dout        <= (others => '0');
                mult_a          <= (others => '0');
                mult_b          <= (others => '0');
                mult_d          <= (others => '0');
                sum_abd         <= (others => '0');
                mult_ic         <= (others => '0');
            elsif rising_edge(clk) then

                -- Held input sample: cleared by reset_corr, otherwise
                -- refreshed on every valid input.
                if reset_corr = '1' then
                    reg_din <= (others => '0');
                elsif matmult_valid = '1' then
                    reg_din <= matmult(I);
                end if;

                -- Stage 1: the three products. mult_b reads reg_din before
                -- the update above takes effect, i.e. the previous sample.
                mult_a  <= matmult(I) * coef_a;
                mult_b  <= reg_din * coef_b;
                mult_d  <= reg_dout * coef_d;

                -- Stage 2: sum of the resized products.
                sum_abd <= f_sum_sat(f_sum_sat(rz_mult_a, rz_mult_b), rz_mult_d);

                -- Stage 3: scaling of the sum.
                mult_ic <= sum_abd * coef_ic;

                -- Stage 4: output register. reset_corr has priority; the load
                -- happens on the third stage of the valid delay line and only
                -- while the line is enabled.
                if reset_corr = '1' then
                    reg_dout    <= (others => '0');
                elsif enable_corr = '1' and delay_valid(2) = '1' then
                    reg_dout    <= f_resize_sat(rnd_abicd, C_W_COR);
                end if;

            end if;
        end process p_main;

        -- Products resized to the width of the sum
        rz_mult_a   <= f_resize_sat(mult_a, C_W_COR_SUMSAT);
        rz_mult_b   <= f_resize_sat(mult_b, C_W_COR_SUMSAT);
        rz_mult_d   <= f_resize_sat(mult_d, C_W_COR_SUMSAT);

        -- Rounding: the reduced-width value is incremented by one unless bit
        -- C_N_COR_RND-1 of mult_ic is '0'. The final resize to C_W_COR is
        -- done where reg_dout is loaded.
        trunc_abicd     <= f_resize_lsb(mult_ic, rnd_abicd'length);
        rnd_abicd       <= trunc_abicd when mult_ic(C_N_COR_RND-1) = '0' else
                           f_sum_sat(trunc_abicd, to_signed(1, rnd_abicd'length));

        -- Output mapping
        corrout(I) <= reg_dout;

    end generate gen_corr;

end architecture;