Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
matrix_mul.vhd 8.86 KiB
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

library desy;
use desy.ram_tdp;
use desy.math_signed.all;

library desyrdl;
use desyrdl.pkg_corr_matrix.t_mem_MATRIXCOEF_2d_out;
use desyrdl.pkg_corr_matrix.t_mem_MATRIXCOEF_2d_in;

use work.pkg_corr_matrix.all;

-- Multiplies a stream of X/Y position samples by coefficient tables addressed
-- by the sample ID and accumulates the products per output channel.
-- Outputs 0 .. C_N_MM_PSC/2-1 accumulate pos_x products, outputs
-- C_N_MM_PSC/2 .. C_N_MM_PSC-1 accumulate pos_y products.
entity matrix_mul is
    port(
        clk                    : in std_logic;
        -- Asynchronous reset, active low
        rst_n                  : in std_logic;

        -- Coef table, desyrdl
        -- Port A access to the coefficient RAMs: element k of mm_coef_i /
        -- mm_coef_o is the table whose coefficients feed matmult(k).
        mm_coef_i              : in t_mem_MATRIXCOEF_2d_out;
        mm_coef_o              : out t_mem_MATRIXCOEF_2d_in;
        -- Reload value of the sample counter: a result is flagged once this many
        -- valid samples have been counted since the last reload.
        id_cnt_load            : in std_logic_vector(C_W_MM_IDCNT-1 downto 0);

        -- Position data in
        -- pos_x/pos_y are multiplied by the coefficients stored at address
        -- pos_id(C_W_MM_IDCNT-1 downto 0). A pos_seq different from the last
        -- accepted one (while pos_tvalid = '1') restarts the accumulation.
        pos_x                  : in signed(C_W_OE-1 downto 0);
        pos_y                  : in signed(C_W_OE-1 downto 0);
        pos_id                 : in std_logic_vector(C_W_BPMID-1 downto 0);
        pos_seq                : in std_logic_vector(C_W_BPMSEQ-1 downto 0);
        pos_tvalid             : in std_logic;

        -- status
        -- Number of completed results counted between the last two rising
        -- edges of pps.
        mult_rate              : out std_logic_vector(15 downto 0);
        pps                    : in std_logic;

        -- Data out
        -- matmult is re-registered every clock from the accumulators (rounded
        -- and saturated); it holds a completed result while matmult_tvalid = '1'.
        -- matmult_seq is the sequence number of the last accepted valid sample.
        matmult                : out signed_array(0 to C_N_MM_PSC-1)(C_W_MM-1 downto 0);
        matmult_tvalid         : out std_logic;
        matmult_seq            : out std_logic_vector(C_W_BPMSEQ-1 downto 0)

        );
end entity;

architecture rtl of matrix_mul is

    -- Pipeline timing, relative to the cycle c0 in which pos_tvalid = '1':
    --   c0 : pos_id addresses the coefficient RAMs (port B); new_seq/rst_accu
    --        is evaluated combinationally from pos_seq and r_seq.
    --   c1 : r_pos_x/y and table_coefx/y line up; r_tvalid(0) = '1' counts
    --        the sample in id_cnt.
    --        NOTE(review): assumes desy.ram_tdp port B has a one-cycle read
    --        latency -- confirm.
    --   c2 : mult_x/y hold the products; ena_accu = r_tvalid(1).
    --   c3 : accu_x/y include the product.
    -- After the last sample of a sequence (c0 = s): id_cnt = 0 at s+2,
    -- mul_done at s+3, matmult and matmult_tvalid at s+4.
    -- NOTE(review): from this timing, the first sample of the next sequence
    -- seems to need to arrive at least 4 cycles after the last sample of the
    -- previous one. Otherwise rst_accu, the id_cnt reload or the r_seq update
    -- clobbers the pending result -- confirm against the data source.

    --------------------------
    -- FUNCTION DECLARATION --
    --------------------------
    -- Drop the C_N_MM_RND least significant bits of an accumulator value,
    -- rounding half up (add one when the most significant dropped bit is set),
    -- then saturate to C_W_MM bits.
    -- NOTE(review): requires C_N_MM_RND >= 1 so that accu(C_N_MM_RND-1) exists.
    function f_round_sat(accu : signed) return signed is
        constant C_W_TRUNC : natural := C_W_MM_ACCU-C_N_MM_RND;
    begin
        if accu(C_N_MM_RND-1) = '0' then
            return f_resize_sat(f_resize_lsb(accu, C_W_TRUNC), C_W_MM);
        else
            return f_resize_sat(f_sum_sat(f_resize_lsb(accu, C_W_TRUNC), to_signed(1, C_W_TRUNC)), C_W_MM);
        end if;
    end function;

    ------------------------
    -- SIGNAL DECLARATION --
    ------------------------
    -- delay registers
    signal r_pos_x    : signed(pos_x'left downto 0);
    signal r_pos_y    : signed(pos_y'left downto 0);
    signal r_seq      : std_logic_vector(C_W_BPMSEQ-1 downto 0);
    signal r_tvalid   : std_logic_vector(3 downto 0);


    -- Accumulators general control
    signal rst_accu     : std_logic;
    signal ena_accu     : std_logic;

    signal id_cnt       : unsigned(C_W_MM_IDCNT-1 downto 0);

    signal new_seq : std_logic;
    signal mul_done : std_logic;
    signal r_mul_done : std_logic;

    -- Rate counter; saturates at C_MUL_CNT_MAX instead of wrapping
    constant C_MUL_CNT_MAX : unsigned(15 downto 0) := (others => '1');
    signal mul_cnt      : unsigned(15 downto 0);
    signal r_mul_cnt    : std_logic_vector(15 downto 0);
    signal pps_r        : std_logic;
begin

    -----------------------
    -- MULT RATE COUNTER --
    -----------------------
    -- Counts mul_done pulses between rising edges of pps and publishes the
    -- count of the elapsed window on mult_rate.
    p_mult_rate:process(clk, rst_n)
    begin
        if rst_n = '0' then
            r_mul_cnt   <= (others => '0');
            mul_cnt     <= (others => '0');
            pps_r       <= '0';
        elsif rising_edge(clk) then
            pps_r   <= pps;
            if pps = '1' and pps_r = '0' then
                r_mul_cnt   <= std_logic_vector(mul_cnt);
                -- A result completing on the PPS edge cycle opens the new
                -- window, so it is not lost from both windows.
                if mul_done = '1' then
                    mul_cnt <= to_unsigned(1, mul_cnt'length);
                else
                    mul_cnt <= (others => '0');
                end if;
            elsif mul_done = '1' and mul_cnt /= C_MUL_CNT_MAX then
                mul_cnt <= mul_cnt + 1;
            end if;
        end if;
    end process;
    mult_rate   <= r_mul_cnt;

    ---------------------
    -- DELAY REGISTERS --
    ---------------------
    -- Align position data with the coefficient RAM output and track pos_tvalid
    -- through the pipeline; r_seq keeps the sequence number of the last valid sample.
    p_reg:process(clk, rst_n)
    begin
        if rst_n = '0' then
            r_pos_x       <= (others => '0');
            r_pos_y       <= (others => '0');
            r_seq         <= (others => '1');
            r_tvalid      <= (others => '0');
        elsif rising_edge(clk) then
            r_pos_x         <= pos_x;
            r_pos_y         <= pos_y;

            if pos_tvalid = '1' then
                r_seq        <= pos_seq;
            end if;

            r_tvalid        <= r_tvalid(r_tvalid'left-1 downto 0) & pos_tvalid;
        end if;
    end process;

    -- Products are registered two cycles after pos_tvalid
    ena_accu <= r_tvalid(1);

    ----------------
    -- SEQ DETECT --
    ----------------
    -- A valid sample whose sequence number differs from the last accepted one
    -- starts a new sequence and clears the accumulators.
    new_seq     <= pos_tvalid when pos_seq /= r_seq else '0';
    rst_accu    <= new_seq;

    ----------------
    -- ID COUNTER --
    ----------------
    -- Reloaded with id_cnt_load on a new sequence, decremented once per valid
    -- sample (r_tvalid(0)). Reaching zero raises mul_done for one cycle and
    -- reloads the counter.
    p_idcnt:process(clk, rst_n)
    begin
        if rst_n = '0' then
            id_cnt <= (others => '1');
            mul_done <= '0';
            r_mul_done <= '0';
        elsif rising_edge(clk) then
            -- One extra cycle to match the p_rndsat output register
            r_mul_done <= mul_done;
            if id_cnt = 0 then
                id_cnt <= unsigned(id_cnt_load);
                mul_done <= '1';
            else
                mul_done <= '0';
                if new_seq= '1' then
                    id_cnt <= unsigned(id_cnt_load);
                else
                    if r_tvalid(0) = '1' then
                        id_cnt <= id_cnt - 1;
                    end if;
                end if;
            end if;
        end if;
    end process;

    ---------------------------
    -- MATRIX MULTIPLICATION --
    ---------------------------
    -- Generate matrix line multiplication, two planes by loop iteration:
    -- iteration I computes output I from pos_x and output I+C_N_MM_PSC/2 from pos_y.
    G_MATRIX:for I in 0 to C_N_MM_PSC/2-1 generate
        signal mult_x   : signed(C_W_MM_MULT-1 downto 0);
        signal mult_y   : signed(C_W_MM_MULT-1 downto 0);
        signal accu_x   : signed(C_W_MM_ACCU-1 downto 0);
        signal accu_y   : signed(C_W_MM_ACCU-1 downto 0);
        signal table_coefx  : std_logic_vector(C_W_MM_COEF-1 downto 0);
        signal table_coefy  : std_logic_vector(C_W_MM_COEF-1 downto 0);


    begin

        ------------------------------------------------------------------
        -- COEF TABLES
        -- Port A: desyrdl access; port B: read-only lookup by pos_id.
        inst_coefx_table: entity desy.ram_tdp
        generic map(
            G_ADDR  => C_W_MM_IDCNT,
            G_DATA  => C_W_MM_COEF
        )
        port map(
            pi_clk_a    => clk,
            pi_en_a     => mm_coef_i(I).en,
            pi_we_a     => mm_coef_i(I).we,
            pi_addr_a   => mm_coef_i(I).addr(C_W_MM_IDCNT-1 downto 0),
            pi_data_a   => mm_coef_i(I).data(C_W_MM_COEF-1 downto 0),
            po_data_a   => mm_coef_o(I).data(C_W_MM_COEF-1 downto 0),
            pi_clk_b    => clk,
            pi_en_b     => '1',
            pi_we_b     => '0',
            pi_addr_b   => pos_id(C_W_MM_IDCNT-1 downto 0),
            pi_data_b   => (others => '0'),
            po_data_b   => table_coefx
        );
        -- Zero-pad the unused upper bits of the read-back word, whatever the
        -- desyrdl data width is
        mm_coef_o(I).data(mm_coef_o(I).data'left downto C_W_MM_COEF) <= (others => '0');

        inst_coefy_table: entity desy.ram_tdp
        generic map(
            G_ADDR  => C_W_MM_IDCNT,
            G_DATA  => C_W_MM_COEF
        )
        port map(
            pi_clk_a    => clk,
            pi_en_a     => mm_coef_i(I+C_N_MM_PSC/2).en,
            pi_we_a     => mm_coef_i(I+C_N_MM_PSC/2).we,
            pi_addr_a   => mm_coef_i(I+C_N_MM_PSC/2).addr(C_W_MM_IDCNT-1 downto 0),
            pi_data_a   => mm_coef_i(I+C_N_MM_PSC/2).data(C_W_MM_COEF-1 downto 0),
            po_data_a   => mm_coef_o(I+C_N_MM_PSC/2).data(C_W_MM_COEF-1 downto 0),
            pi_clk_b    => clk,
            pi_en_b     => '1',
            pi_we_b     => '0',
            pi_addr_b   => pos_id(C_W_MM_IDCNT-1 downto 0),
            pi_data_b   => (others => '0'),
            po_data_b   => table_coefy
        );
        mm_coef_o(I+C_N_MM_PSC/2).data(mm_coef_o(I+C_N_MM_PSC/2).data'left downto C_W_MM_COEF) <= (others => '0');

        ------------------------------------------------------------------
        -- MULT ACCU
        -- Registered multiply, then saturating accumulation. A new sequence
        -- (rst_accu) takes priority over accumulation.
        p_multaccu:process(clk, rst_n)
        begin
            if rst_n = '0' then
                mult_x    <= (others => '0');
                accu_x    <= (others => '0');
                mult_y    <= (others => '0');
                accu_y    <= (others => '0');
            elsif rising_edge(clk) then

                mult_x <= r_pos_x * signed(table_coefx);
                mult_y <= r_pos_y * signed(table_coefy);

                if rst_accu = '1' then
                    accu_x    <= (others => '0');
                    accu_y    <= (others => '0');
                elsif ena_accu = '1' then
                    accu_x  <= f_sum_sat(accu_x, mult_x);
                    accu_y  <= f_sum_sat(accu_y, mult_y);
                end if;
            end if;
        end process;

        ------------------------------------------------------------------
        -- ROUND, SATURATE AND MAP RESULT
        p_rndsat:process(clk, rst_n)
        begin
            if rst_n = '0' then
                matmult(I)              <= (others => '0');
                matmult(I+C_N_MM_PSC/2) <= (others => '0');
            elsif rising_edge(clk) then
                matmult(I)              <= f_round_sat(accu_x);
                matmult(I+C_N_MM_PSC/2) <= f_round_sat(accu_y);
            end if;
        end process;

    end generate;

    --------------------
    -- OUTPUT CONNECT --
    --------------------
    matmult_tvalid  <= r_mul_done;
    matmult_seq     <= r_seq;

end architecture;