-- Computes a per-ID moving average over an AXI-Stream of samples, each tagged with an ID.
-- Results are published in a memory readable over AXI-MM.
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

library desy;
use desy.ram_tdp;
use desy.math_signed.all;

-- Per-ID moving-average engine.
--   G_W_ID     : width of the sample ID; the ID also addresses the result tables
--   G_W_SIGNAL : width of the signed input samples and of the published averages
--   G_W_ALPHA  : width of the signed smoothing coefficient alpha
entity moving_average is
    generic(
        G_W_ID, G_W_SIGNAL, G_W_ALPHA : natural
    );
    port(
        -- Clock and active-low reset
        clk, rst_n      : in  std_logic;

        -- Smoothing coefficient (fixed point)
        alpha           : in  signed(G_W_ALPHA-1 downto 0);

        -- Sample stream: value and ID, qualified by sig_valid
        sig_data        : in  signed(G_W_SIGNAL-1 downto 0);
        sig_id          : in  std_logic_vector(G_W_ID-1 downto 0);
        sig_valid       : in  std_logic;

        -- Read port of the averaged-signal table (AXI-MM side), indexed by ID
        asigt_en        : in  std_logic;
        asigt_addr      : in  std_logic_vector(G_W_ID-1 downto 0);
        asigt_rdata     : out std_logic_vector(G_W_SIGNAL-1 downto 0)
    );
end entity moving_average;

architecture rtl of moving_average is

    -- Per-ID exponential moving average:
    --   ma_new = ma_old - alpha*ma_old + alpha*sig_data
    -- The average is stored at G_W_SIGNAL+G_W_ALPHA bits per ID. Both rounding
    -- steps discard G_W_ALPHA LSBs, so those act as fractional bits. alpha
    -- presumably carries G_W_ALPHA fractional bits too, because sig_data*alpha
    -- is added to the table value without rescaling. TODO: confirm the intended
    -- alpha format.
    --
    -- Pipeline (cycle 0 = cycle in which sig_valid is high):
    --   cycle 0 : sig_id drives the main-table read address; sig_data*alpha is
    --             registered at the end of the cycle.
    --   cycle 1 : table_rdata holds ma_old (assumes a 1-cycle ram_tdp read
    --             latency); -alpha*ma_old is formed and rounded; ma_update is
    --             registered at the end of the cycle.
    --   cycle 2 : ma_new is registered at the end of the cycle.
    --   cycle 3 : table_wen = sig_valid_r(2); ma_new is written to both tables
    --             at address sig_id_r(2).
    -- NOTE(review): nothing forwards the in-flight result back to the read
    -- path. If a sample arrives with an ID that is still in the pipeline
    -- (within about four cycles), it reads a stale ma_old and the earlier
    -- update is overwritten. Presumably upstream guarantees ID spacing; confirm.

    type arr_slv is array (natural range <>) of std_logic_vector;

    ------------------------
    -- SIGNAL DECLARATION --
    ------------------------
    -- Write port (port B) of both tables, driven by the last pipeline stage
    signal table_wen        : std_logic;
    signal table_waddr      : std_logic_vector(G_W_ID-1 downto 0);
    signal table_wdata      : std_logic_vector(G_W_SIGNAL+G_W_ALPHA-1 downto 0);
    signal table_wdata_sec  : std_logic_vector(G_W_SIGNAL-1 downto 0);
    -- Previous average, read from the main table (port A)
    signal table_rdata      : std_logic_vector(G_W_SIGNAL+G_W_ALPHA-1 downto 0);

    -- ID / valid delay lines that align the write-back with ma_new
    signal sig_id_r         : arr_slv(0 to 2)(G_W_ID-1 downto 0);
    signal sig_valid_r      : std_logic_vector(0 to 2);

    -- Previous average (ma_old) as a signed value
    signal pasig            : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);

    -- sig_data*alpha, and a copy delayed one cycle to line up with ma_update
    signal sig_alpha        : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);
    signal sig_alpha_r      : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);

    -- -alpha*ma_old at full product precision, then rounded to table precision
    signal pasig_alpha      : signed(G_W_SIGNAL+G_W_ALPHA*2-1 downto 0);
    signal pasig_alpha_rnd  : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);
    -- ma_old - alpha*ma_old
    signal ma_update        : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);
    -- New average at table precision, and rounded to G_W_SIGNAL bits
    signal ma_new           : signed(G_W_SIGNAL+G_W_ALPHA-1 downto 0);
    signal ma_new_rnd       : signed(G_W_SIGNAL-1 downto 0);

    ---------------
    -- FUNCTIONS --
    ---------------
    -- Round-half-up reduction of x to 'width' bits.
    -- The reduced value is f_resize_lsb(x, width). When the most significant
    -- discarded bit is not '0', one LSB is added with f_sum_sat.
    -- This relies on f_resize_lsb discarding the (x'length - width) LSBs; confirm
    -- against desy.math_signed.
    function f_round_half_up(x : signed; width : natural) return signed is
        -- Index of the most significant bit removed by the reduction
        constant C_RND_BIT : natural := x'low + x'length - width - 1;
    begin
        if x(C_RND_BIT) = '0' then
            return f_resize_lsb(x, width);
        else
            return f_sum_sat(f_resize_lsb(x, width), to_signed(1, width));
        end if;
    end function f_round_half_up;

begin


    ----------------------------
    -- AVERAGED SIGNAL TABLES --
    ----------------------------
    -- Main table: holds the average at full G_W_SIGNAL+G_W_ALPHA bit precision
    -- and feeds the computation. Port A only reads (addressed by the incoming
    -- sig_id); port B only writes the pipeline result.

    inst_asig_main_table: entity desy.ram_tdp
    generic map(
        G_ADDR      => G_W_ID,
        G_DATA      => G_W_SIGNAL+G_W_ALPHA
    )
    port map(
        pi_clk_a    => clk,
        pi_en_a     => '1',
        pi_we_a     => '0',
        pi_addr_a   => sig_id,
        pi_data_a   => (others => '0'),
        po_data_a   => table_rdata,
        pi_clk_b    => clk,
        pi_en_b     => '1',
        pi_we_b     => table_wen,
        pi_addr_b   => table_waddr,
        pi_data_b   => table_wdata,
        po_data_b   => open
    );
    table_wdata <= std_logic_vector(ma_new);

    -- Secondary table: copy of the average for AXI-MM readout. Port A only
    -- reads (AXI-MM side). Port B mirrors the main table's writes (same enable
    -- and address) but stores only the rounded G_W_SIGNAL-bit value; the
    -- fractional part is not needed here.

    inst_asig_sec_table: entity desy.ram_tdp
    generic map(
        G_ADDR      => G_W_ID,
        G_DATA      => G_W_SIGNAL
    )
    port map(
        pi_clk_a    => clk,
        pi_en_a     => asigt_en,
        pi_we_a     => '0',
        pi_addr_a   => asigt_addr,
        pi_data_a   => (others => '0'),
        po_data_a   => asigt_rdata,
        pi_clk_b    => clk,
        pi_en_b     => '1',
        pi_we_b     => table_wen,
        pi_addr_b   => table_waddr,
        pi_data_b   => table_wdata_sec,
        po_data_b   => open
    );

    -- Round to G_W_SIGNAL bits before storing in the secondary table
    table_wdata_sec <= std_logic_vector(ma_new_rnd);
    ma_new_rnd      <= f_round_half_up(ma_new, ma_new_rnd'length);


    ------------------------
    -- PIPELINE REGISTERS --
    ------------------------
    -- Delay ID and valid by three cycles so the write-back enable and address
    -- coincide with ma_new.
    p_pipe:process(clk, rst_n)
    begin
        if rst_n = '0' then
            sig_id_r         <= (others => (others => '0'));
            sig_valid_r      <= (others => '0');
        elsif rising_edge(clk) then

            sig_valid_r <= sig_valid & sig_valid_r(0 to sig_valid_r'right-1);

            sig_id_r(0) <= sig_id;
            for I in 1 to sig_id_r'right loop
                sig_id_r(I) <= sig_id_r(I-1);
            end loop;

        end if;
    end process;

    table_wen       <= sig_valid_r(2);
    table_waddr     <= sig_id_r(2);


    -----------------
    -- COMPUTATION --
    -----------------
    pasig                <= signed(table_rdata);
    -- VHDL parses "-a * b" as -(a * b). The product is formed at full width
    -- first, and its magnitude never reaches the most negative value, so the
    -- negation cannot overflow.
    pasig_alpha          <= -pasig * alpha;
    pasig_alpha_rnd      <= f_round_half_up(pasig_alpha, pasig_alpha_rnd'length);

    p_comp:process(clk, rst_n)
    begin
        if rst_n = '0' then
            sig_alpha        <= (others => '0');
            sig_alpha_r      <= (others => '0');
            ma_update        <= (others => '0');
            ma_new           <= (others => '0');

        elsif rising_edge(clk) then

            -- alpha*sig_data, delayed one extra cycle to meet ma_update
            sig_alpha        <= sig_data * alpha;
            sig_alpha_r      <= sig_alpha;

            -- ma_old - alpha*ma_old (f_sum_sat: saturating add from desy.math_signed)
            ma_update        <= f_sum_sat(pasig, pasig_alpha_rnd);
            -- ... + alpha*sig_data
            ma_new           <= f_sum_sat(sig_alpha_r, ma_update);

        end if;
    end process;



end architecture;