# -*- coding: utf-8 -*-
"""
This module defines general classes to describe wakefields, impedances and
wake functions. Based on the impedance library by David Amorim.

@author: David Amorim, Alexis Gamelin
@date: 14/01/2020
"""

import warnings
import re
import pandas as pd
import numpy as np
from scipy.interpolate import interp1d
from tracking.element import Element


class ComplexData:
    """
    Define a general data structure for a complex function based on a pandas
    DataFrame.

    The real and imaginary parts of the function are stored in the 'real'
    and 'imag' columns of the ``data`` attribute, whose index holds the
    function variable.

    Parameters
    ----------
    variable : list, numpy array
       contains the function variable values

    function : list or numpy array of complex numbers
        contains the values taken by the complex function
    """

    def __init__(self, variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0])):
        # Build a fresh Index from the raw values so that naming it here can
        # never rename an Index object the caller passed in.
        index = pd.Index(np.asarray(variable), name='variable')
        self.data = pd.DataFrame({'real': np.real(function),
                                  'imag': np.imag(function)},
                                 index=index)

    @staticmethod
    def _merge_variable(data, other_variable):
        """
        Return a copy of ``data`` with NaN rows appended at the points of
        ``other_variable`` that are absent from ``data.index``.

        Duplicated index values are dropped, keeping the first occurrence.
        """
        padding = pd.DataFrame(index=other_variable)
        merged = pd.concat([data, padding], sort=True)
        return merged[~merged.index.duplicated(keep='first')].copy()

    @staticmethod
    def _range_mask(index, min_variable, max_variable):
        """Return a boolean mask of the index values in [min, max]."""
        return (index <= max_variable) & (index >= min_variable)

    def add(self, structure_to_add, method='zero', interp_kind='cubic',
            index_name="variable"):
        """
        Method to add two structures. If the data don't have the same length,
        different cases are handled.

        Parameters
        ----------
        structure_to_add : ComplexData object, int, float or complex.
            structure to be added. Numpy numeric scalars are accepted too.
        method : str, optional
            manage how the addition is done, possibilities are:
            -common: the common range of the index (variable) from the two
                structures is used. The input data are cross-interpolated
                so that the result data contains the points at all initial
                points in the common domain.
            -extrapolate: not implemented yet; a warning is issued and the
                initial structure is returned unchanged.
            -zero: outside of the common range, the missing values are zero.
                In the common range, behave as "common" method.
        interp_kind : str, optional
            interpolation method which is passed to pandas and to
            scipy.interpolate.interp1d.
        index_name : str, optional
            name given to the index of the returned structure.

        Returns
        -------
        ComplexData
            Contains the sum of the two inputs.

        Raises
        ------
        ValueError
            If ``method`` is not 'common', 'extrapolate' or 'zero'.
        """
        if method not in ('common', 'extrapolate', 'zero'):
            raise ValueError("Unknown addition method '{}'. Possibilities "
                             "are 'common', 'extrapolate' and 'zero'."
                             .format(method))

        if isinstance(structure_to_add, (int, float, complex, np.number)):
            # A scalar becomes a constant function sampled on our own grid.
            # The raw values are passed so that the temporary structure does
            # not share our Index object.
            structure_to_add = ComplexData(
                variable=self.data.index.values,
                function=structure_to_add * np.ones(len(self.data.index)))

        if method == 'extrapolate':
            warnings.warn("The 'extrapolate' method is not implemented yet. "
                          "Returning the initial structure.", UserWarning)
            return self

        # Put both structures on the union of their variable points: each one
        # keeps its own values and gets NaN at the points only the other has.
        initial_data = self._merge_variable(self.data,
                                            structure_to_add.data.index)
        data_to_add = self._merge_variable(structure_to_add.data,
                                           self.data.index)

        # Boundaries of the variable range covered by both structures.
        max_variable = min(structure_to_add.data.index.max(),
                           self.data.index.max())
        min_variable = max(structure_to_add.data.index.min(),
                           self.data.index.min())

        if method == 'common':
            initial_data = initial_data.interpolate(method=interp_kind)
            initial_data = initial_data[self._range_mask(
                initial_data.index, min_variable, max_variable)]
            data_to_add = data_to_add.interpolate(method=interp_kind)
            data_to_add = data_to_add[self._range_mask(
                data_to_add.index, min_variable, max_variable)]
        else:  # method == 'zero'
            # Interpolate only inside the common range; the remaining NaN
            # (outside of it) are then replaced by zero.
            for frame in (initial_data, data_to_add):
                mask = self._range_mask(frame.index, min_variable,
                                        max_variable)
                frame[mask] = frame[mask].interpolate(method=interp_kind)
            initial_data = initial_data.fillna(0)
            data_to_add = data_to_add.fillna(0)

        result_structure = ComplexData()
        result_structure.data = initial_data + data_to_add
        result_structure.data.index.name = index_name
        return result_structure

    def __radd__(self, structure_to_add):
        return self.add(structure_to_add, method='zero')

    def __add__(self, structure_to_add):
        return self.add(structure_to_add, method='zero')

    def multiply(self, factor):
        """
        Multiply a data structure with a real number (Python or numpy int or
        float). If the multiplication is done with something else, throw a
        warning and return the structure unchanged.
        """
        # Only real factors are accepted: the real and imaginary parts live
        # in separate columns and are each scaled by the factor.
        if isinstance(factor, (int, float, np.integer, np.floating)):
            result_structure = ComplexData()
            result_structure.data = self.data * factor
            return result_structure
        warnings.warn(('The multiplication factor is not a float '
                       'or an int.'), UserWarning)
        return self

    def __mul__(self, factor):
        return self.multiply(factor)

    def __rmul__(self, factor):
        return self.multiply(factor)

    def __call__(self, values, interp_kind="cubic"):
        """
        Interpolation of the data by calling the class to have a function-like
        behaviour.

        Parameters
        ----------
        values : list or numpy array of complex, int or float
            values to be interpolated.
        interp_kind : str, optional
            interpolation method which is passed to scipy.interpolate.interp1d.

        Returns
        -------
        numpy array
            Contains the interpolated data.
        """
        real_func = interp1d(x=self.data.index,
                             y=self.data["real"], kind=interp_kind)
        imag_func = interp1d(x=self.data.index,
                             y=self.data["imag"], kind=interp_kind)
        return real_func(values) + 1j*imag_func(values)
    
class WakeFunction(ComplexData):
    """
    Define a WakeFunction object based on a ComplexData object.

    Parameters
    ----------
    variable : list, numpy array
        points at which the wake function is sampled.
    function : list or numpy array of complex numbers
        values taken by the wake function at each point.
    wake_type : str, optional
        name of the wake component, stored in the ``wake_type`` attribute.
        Default is 'long'.
    """

    def __init__(self,
                 variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0]), wake_type='long'):
        super().__init__(variable, function)
        # Keep the component name so that containers can identify the wake.
        self.wake_type = wake_type
    
class Impedance(ComplexData):
    """
    Define an Impedance object based on a ComplexData object.

    Parameters
    ----------
    variable : list, numpy array
        frequency values, stored in the index named 'frequency [Hz]'.
    function : list or numpy array of complex numbers
        impedance values at each frequency.
    impedance_type : str, optional
        name of the impedance component, one of the columns of
        impedance_name_and_coefficients_table (e.g. 'long', 'xdip').
        Default is 'long'.

    Raises
    ------
    ValueError
        If impedance_type is not a known component name.
    """

    def __init__(self,
                 variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0]), impedance_type='long'):
        super().__init__(variable, function)
        self._impedance_type = impedance_type
        self.data.index.name = 'frequency [Hz]'
        self.initialize_impedance_coefficient()

    def initialize_impedance_coefficient(self):
        """
        Define the impedance coefficients (a, b, c, d) and the plane of the
        impedance from impedance_type, and store them as attributes.

        Raises
        ------
        ValueError
            If impedance_type is not a column of
            impedance_name_and_coefficients_table.
        """
        table = self.impedance_name_and_coefficients_table()

        try:
            component_coefficients = table[self.impedance_type].to_dict()
        except KeyError:
            # Without a matching column there are no coefficients to assign.
            raise ValueError('Impedance type {} does not exist'.format(
                self.impedance_type)) from None

        self.a = component_coefficients["a"]
        self.b = component_coefficients["b"]
        self.c = component_coefficients["c"]
        self.d = component_coefficients["d"]
        self.plane = component_coefficients["plane"]

    def impedance_name_and_coefficients_table(self):
        """
        Return a table associating the human readable names of an impedance
        component and its associated coefficients and plane.
        """

        component_dict = {
            'long': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'z'},
            'xcst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
            'ycst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
            'xdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
            'ydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'y'},
            'xydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'x'},
            'yxdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
            'xquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'x'},
            'yquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'y'},
            'xyquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'x'},
            'yxquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'y'},
            }

        return pd.DataFrame(component_dict)

    def _from_complex_data(self, structure):
        """Wrap the data of ``structure`` into a new Impedance of our type."""
        return Impedance(
            structure.data.index.values,
            structure.data["real"].values + 1j*structure.data["imag"].values,
            self.impedance_type)

    def add(self, structure_to_add, beta_x=1, beta_y=1, method='zero'):
        """
        Method to add two Impedance objects. The two structures are
        compared so that the addition is self-consistent.
        Beta functions can be specified as well.

        Parameters
        ----------
        structure_to_add : Impedance, int, float or complex
            structure to be added. A scalar is added as a constant impedance
            sampled on the frequency points of this object.
        beta_x, beta_y : float, optional
            beta functions used to weight structure_to_add when it is an
            Impedance: weight = beta_x**power_x * beta_y**power_y.
        method : str, optional
            addition method passed to ComplexData.add.

        Returns
        -------
        Impedance
            Sum of the two inputs, with the impedance_type of this object.
            If structure_to_add is an Impedance of another type, a warning is
            issued and a copy of this object is returned.

        Raises
        ------
        TypeError
            If structure_to_add is neither a number nor an Impedance.
        """
        if isinstance(structure_to_add, (int, float, complex, np.number)):
            # Build the constant as an Impedance so that its index already
            # carries the 'frequency [Hz]' name expected by ComplexData.add.
            constant = Impedance(
                self.data.index.values,
                structure_to_add * np.ones(len(self.data.index)),
                self.impedance_type)
            result = super().add(constant, method=method,
                                 index_name="frequency [Hz]")
        elif isinstance(structure_to_add, Impedance):
            if self.impedance_type == structure_to_add.impedance_type:
                weight = (beta_x ** self.power_x) * (beta_y ** self.power_y)
                result = super().add(structure_to_add * weight, method=method,
                                     index_name="frequency [Hz]")
            else:
                warnings.warn(('The two Impedance objects do not have the '
                               'same coordinates or plane or type. '
                               'Returning initial Impedance object.'),
                              UserWarning)
                result = self
        else:
            raise TypeError("Cannot add an object of type {} to an "
                            "Impedance.".format(type(structure_to_add).__name__))

        return self._from_complex_data(result)

    def __radd__(self, structure_to_add):
        return self.add(structure_to_add)

    def __add__(self, structure_to_add):
        return self.add(structure_to_add)

    def multiply(self, factor):
        """
        Multiply an Impedance object with a float or an int.
        If the multiplication is done with something else, throw a warning.
        """
        return self._from_complex_data(super().multiply(factor))

    def __mul__(self, factor):
        return self.multiply(factor)

    def __rmul__(self, factor):
        return self.multiply(factor)

    @property
    def impedance_type(self):
        """Name of the impedance component (a column of the coefficient table)."""
        return self._impedance_type

    @impedance_type.setter
    def impedance_type(self, value):
        # Roll back on an unknown type so that the name and the coefficients
        # never disagree.
        previous = self._impedance_type
        self._impedance_type = value
        try:
            self.initialize_impedance_coefficient()
        except ValueError:
            self._impedance_type = previous
            raise

    @property
    def power_x(self):
        """Exponent of beta_x in the weight: (a + c)/2, plus 1/2 in plane 'x'."""
        power_x = self.a/2 + self.c/2.
        if self.plane == 'x':
            power_x += 1./2.
        return power_x

    @property
    def power_y(self):
        """Exponent of beta_y in the weight: (b + d)/2, plus 1/2 in plane 'y'."""
        power_y = self.b/2. + self.d/2.
        if self.plane == 'y':
            power_y += 1./2.
        return power_y
    
class Wakefield:
    """
    Defines a Wakefield which corresponds to a single physical element which
    produces different types of wakes, represented by their Impedance or
    WakeFunction objects.

    Each component is stored as an attribute named "Z" + impedance_type for
    an Impedance, or "W" + wake_type for a WakeFunction.

    Parameters
    ----------
    structure_list : list, optional
        list of Impedance/WakeFunction objects to add to the Wakefield.
    name : str, optional

    Attributes
    ----------
    impedance_components : array of str
        Impedance component names present for this element.

    Methods
    -------
    append_to_model(structure_to_add)
        Add Impedance/WakeFunction to Wakefield.
    list_to_attr(structure_list)
        Add list of Impedance/WakeFunction to Wakefield.
    """

    def __init__(self, structure_list=None, name=None):
        self.list_to_attr(structure_list)
        self.name = name

    def append_to_model(self, structure_to_add):
        """
        Add Impedance/WakeFunction component to Wakefield.

        Parameters
        ----------
        structure_to_add : Impedance or WakeFunction object

        Raises
        ------
        ValueError
            If structure_to_add is neither an Impedance nor a WakeFunction,
            if a WakeFunction has no wake_type, or if a component of the
            same type is already present.
        """
        if isinstance(structure_to_add, Impedance):
            attribute_name = "Z" + structure_to_add.impedance_type
        elif isinstance(structure_to_add, WakeFunction):
            # A WakeFunction has no impedance_type: its component name is
            # expected in a wake_type attribute (its constructor argument).
            wake_type = getattr(structure_to_add, "wake_type", None)
            if wake_type is None:
                raise ValueError("{} does not define a wake_type."
                                 .format(structure_to_add))
            attribute_name = "W" + wake_type
        else:
            raise ValueError("{} is not an Impedance nor a WakeFunction."
                             .format(structure_to_add))

        if attribute_name in dir(self):
            raise ValueError("There is already a component of the type "
                             "{} in this element.".format(attribute_name))
        setattr(self, attribute_name, structure_to_add)

    def list_to_attr(self, structure_list):
        """
        Add list of Impedance/WakeFunction components to Wakefield.

        Parameters
        ----------
        structure_list : list of Impedance or WakeFunction objects, or None
            Nothing is added when None.
        """
        if structure_list is not None:
            for component in structure_list:
                self.append_to_model(component)

    @property
    def impedance_components(self):
        """
        Return an array of the impedance component names for the element,
        i.e. the attribute names starting with "Z".
        """
        return np.array([comp for comp in dir(self) if comp.startswith('Z')])
    
    
class ImpedanceModel(Element):
    """
    Define the impedance model of the machine.

    Parameters
    ----------
    ring : Synchrotron object
        Must provide an ``optics`` attribute whose ``beta`` method is called
        with the Wakefield positions in sum_elements.
    wakefield_list : list of Wakefield objects, optional
        Wakefields to add to the model.
    wakefiled_postions : list, optional
        Longitudinal positions corresponding to the added Wakefields.

    Attributes
    ----------
    wakefields : list of Wakefield
        Wakefields of the model, in insertion order.
    positions : numpy array
        Positions of the Wakefields, aligned with ``wakefields``.
    sum_wakefield : Wakefield
        Weighted sum of the Impedance components; only set once
        sum_elements has been called.
    """

    def __init__(self, ring, wakefield_list=None, wakefiled_postions=None):
        self.ring = ring
        self.optics = self.ring.optics
        self.wakefields = []
        self.positions = np.array([])
        self.add(wakefield_list, wakefiled_postions)

    def track(self, beam):
        """
        Track a beam object through this Element.

        Parameters
        ----------
        beam : Beam object
        """
        raise NotImplementedError

    def sum_elements(self):
        """
        Sum all the elements in the model into sum_wakefield.

        For each component in Zlong, Zxdip, Zydip, Zxquad and Zyquad, the
        Impedance of every Wakefield that has it is weighted by
        beta[0, i]**power_x * beta[1, i]**power_y and added together.
        """
        # NOTE(review): beta is indexed as beta[0, i] / beta[1, i], presumably
        # the horizontal / vertical beta functions at position i — confirm.
        beta = self.optics.beta(self.positions)
        list_impedance_components = ["Zlong", "Zxdip", "Zydip",
                                     "Zxquad", "Zyquad"]
        sum_wakefield = Wakefield()
        for component_name in list_impedance_components:
            sum_imp = Impedance(variable=np.array([0, 1]),
                                function=np.array([0, 0]),
                                impedance_type=component_name[1:])
            for index, wakefield in enumerate(self.wakefields):
                impedance = getattr(wakefield, component_name, None)
                if impedance is None:
                    # This Wakefield has no such component: nothing to add.
                    continue
                # Coerce to a built-in float so the multiplication is handled
                # by Impedance itself whatever the dtype of beta.
                weight = float((beta[0, index] ** impedance.power_x)
                               * (beta[1, index] ** impedance.power_y))
                sum_imp += impedance * weight
            sum_wakefield.append_to_model(sum_imp)
        self.sum_wakefield = sum_wakefield

    def add(self, wakefield_list, wakefiled_postions):
        """
        Add elements to the model.

        Parameters
        ----------
        wakefield_list : list of Wakefield objects or None
            Wakefields to add to the model.
        wakefiled_postions : list or None
            Longitudinal positions corresponding to the added Wakefields.

        Raises
        ------
        ValueError
            If only one of the two arguments is None, or if they do not
            have the same number of elements.
        """
        if wakefield_list is None and wakefiled_postions is None:
            return
        if wakefield_list is None or wakefiled_postions is None:
            raise ValueError("wakefield_list and wakefiled_postions must be "
                             "given together.")

        wakefield_list = list(wakefield_list)
        new_positions = np.ravel(np.asarray(wakefiled_postions, dtype=float))
        if len(wakefield_list) != len(new_positions):
            raise ValueError("{} Wakefields were given with {} positions."
                             .format(len(wakefield_list), len(new_positions)))

        self.wakefields.extend(wakefield_list)
        self.positions = np.append(self.positions, new_positions)