import numpy as np
from mbtrack2 import Bunch, Beam

def assert_attr_changed(element,
                        bunch,
                        attrs_changed=("xp", "yp", "delta"),
                        change=True):
    """Track through ``element`` and check which coordinates were modified.

    Dispatches on the type of ``bunch``: a ``Bunch`` is handled by
    ``assert_attr_changed_bunch`` and a ``Beam`` by
    ``assert_attr_changed_beam``. ``element``, ``attrs_changed`` and
    ``change`` are forwarded unchanged.

    Parameters
    ----------
    element : object
        Object whose ``track`` method is called on ``bunch``.
    bunch : Bunch or Beam
        Particles to track.
    attrs_changed : sequence of str, optional
        Coordinates whose modification is governed by ``change``.
        Default is ``("xp", "yp", "delta")``.
    change : bool, optional
        If True, the attributes in ``attrs_changed`` are expected to change.
        If False, they are expected to stay identical.

    Raises
    ------
    TypeError
        If ``bunch`` is neither a ``Bunch`` nor a ``Beam``.
    AssertionError
        Propagated from the selected helper when an expectation fails.
    """
    if isinstance(bunch, Bunch):
        assert_attr_changed_bunch(element, bunch, attrs_changed, change=change)
    elif isinstance(bunch, Beam):
        assert_attr_changed_beam(element, bunch, attrs_changed, change=change)
    else:
        raise TypeError(
            f"expected a Bunch or Beam, got {type(bunch).__name__}")

def assert_attr_changed_bunch(element,
                              bunch,
                              attrs_changed=("xp", "yp", "delta"),
                              change=True):
    """Track a single bunch through ``element`` and check its coordinates.

    ``element.track(bunch)`` is called exactly once. Each coordinate is
    compared before and after that call with ``np.array_equal``.

    Parameters
    ----------
    element : object
        Anything with a ``track(bunch)`` method.
    bunch : Bunch
        Indexed by coordinate name. Each ``bunch[attr]`` must support
        ``.copy()``. It is presumably a numpy array; confirm against ``Bunch``.
    attrs_changed : sequence of str, optional
        Coordinates whose expected behaviour is set by ``change``.
        Default is ``("xp", "yp", "delta")``.
    change : bool, optional
        If True, every attribute in ``attrs_changed`` must differ from its
        value before tracking. If False, every one must be identical.

    Raises
    ------
    AssertionError
        On the first attribute that violates its expectation; the message
        is the attribute name. Any of ``x, xp, y, yp, tau, delta`` not listed
        in ``attrs_changed`` must be identical after tracking.
    """
    attrs = ["x", "xp", "y", "yp", "tau", "delta"]
    attrs_unchanged = [attr for attr in attrs if attr not in attrs_changed]

    # Take copies rather than references: the element may update the arrays
    # in place, and a reference would then be compared with itself.
    initial_values_changed = {attr: bunch[attr].copy() for attr in attrs_changed}
    initial_values_unchanged = {attr: bunch[attr].copy() for attr in attrs_unchanged}

    element.track(bunch)

    for attr in attrs_changed:
        same = np.array_equal(initial_values_changed[attr], bunch[attr])
        if change:
            assert not same, f"{attr}"
        else:
            assert same, f"{attr}"

    for attr in attrs_unchanged:
        assert np.array_equal(initial_values_unchanged[attr], bunch[attr]), f"{attr}"
        
def assert_attr_changed_beam(element,
                             beam,
                             attrs_changed=("xp", "yp", "delta"),
                             change=True):
    """Track a whole beam through ``element`` and check every bunch.

    ``element.track(beam)`` is called exactly once. For each bunch, in
    iteration order, every coordinate is compared before and after that
    call with ``np.array_equal``.

    Parameters
    ----------
    element : object
        Anything with a ``track(beam)`` method.
    beam : Beam
        Iterable of bunches. Each bunch is indexed by coordinate name and has
        a ``charge`` attribute.
    attrs_changed : sequence of str, optional
        Coordinates whose expected behaviour is set by ``change``.
        Default is ``("xp", "yp", "delta")``.
    change : bool, optional
        If True, the attributes in ``attrs_changed`` must differ after
        tracking for every bunch with non-zero charge. Bunches with zero
        charge must always stay identical. If False, all must be identical.

    Raises
    ------
    AssertionError
        On the first attribute that violates its expectation; the message
        is the attribute name. Any of ``x, xp, y, yp, tau, delta`` not listed
        in ``attrs_changed`` must be identical after tracking.
    """
    attrs = ["x", "xp", "y", "yp", "tau", "delta"]
    attrs_unchanged = [attr for attr in attrs if attr not in attrs_changed]

    # Per-bunch snapshots, copied so that in-place updates by the element
    # cannot alter the reference values.
    before_changed = [{attr: bunch[attr].copy() for attr in attrs_changed}
                      for bunch in beam]
    before_unchanged = [{attr: bunch[attr].copy() for attr in attrs_unchanged}
                        for bunch in beam]

    element.track(beam)

    for bunch, initial_changed, initial_unchanged in zip(
            beam, before_changed, before_unchanged):
        # A zero-charge bunch is expected to come out untouched even when
        # ``change`` is requested.
        expect_change = change and (bunch.charge != 0)
        for attr in attrs_changed:
            same = np.array_equal(initial_changed[attr], bunch[attr])
            if expect_change:
                assert not same, f"{attr}"
            else:
                assert same, f"{attr}"

        for attr in attrs_unchanged:
            assert np.array_equal(initial_unchanged[attr], bunch[attr]), f"{attr}"