# -*- coding: utf-8 -*-
"""
This module defines general classes to describe wakefields, impedances and
wake functions. Based on the impedance library by David Amorim.

@author: David Amorim, Alexis Gamelin
@date: 14/01/2020
"""

import warnings
import re
import pandas as pd
import numpy as np
from scipy.interpolate import interp1d

class ComplexData:
    """
    Define a general data structure for a complex function based on a pandas
    DataFrame.

    The real and imaginary parts of the function are stored in the 'real'
    and 'imag' columns of the ``data`` attribute, indexed by the variable.

    Parameters
    ----------
    variable : list, numpy array
       contains the function variable values

    function : list or numpy array of complex numbers
        contains the values taken by the complex function
    """

    def __init__(self, variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0])):
        self.data = pd.DataFrame({'real': np.real(function),
                                  'imag': np.imag(function)},
                                 index=variable)
        self.data.index.name = 'variable'

    @staticmethod
    def _extend_index(data, other_index):
        """
        Return ``data`` reindexed on the sorted union of its own index and
        ``other_index``. Rows coming only from ``other_index`` are NaN.
        Duplicated variable values in ``data`` are dropped, keeping the first
        occurrence.
        """
        unique_data = data[~data.index.duplicated(keep='first')]
        # Index.union does not depend on the index name, so this also works
        # for subclasses which rename the index (e.g. 'frequency [Hz]').
        merged_index = unique_data.index.union(other_index.unique())
        return unique_data.reindex(merged_index)

    def add(self, structure_to_add, method='zero', interp_kind='cubic'):
        """
        Method to add two structures. If the data don't have the same length,
        different cases are handled.

        Parameters
        ----------
        structure_to_add : ComplexData object, int, float or complex.
            structure to be added. Scalars (including numpy scalars) are
            turned into a constant ComplexData defined on the variable of
            self.
        method : str, optional
            manage how the addition is done, possibilities are:
            -common: the common range of the index (variable) from the two
                structures is used. The input data are cross-interpolated
                so that the result data contains the points at all initial
                points in the common domain.
            -extrapolate: not implemented yet; a UserWarning is emitted and
                self is returned unchanged.
            -zero: outside of the common range, the missing values are zero. In
                the common range, behave as "common" method.
        interp_kind : str, optional
            interpolation method which is passed to
            pandas.DataFrame.interpolate.

        Returns
        -------
        ComplexData
            Contains the sum of the two inputs.

        Raises
        ------
        ValueError
            If method is not 'common', 'extrapolate' or 'zero'.
        """
        if method not in ('common', 'extrapolate', 'zero'):
            raise ValueError("Unknown method {!r}: possible values are "
                             "'common', 'extrapolate' and 'zero'."
                             .format(method))

        if isinstance(structure_to_add, (int, float, complex, np.number)):
            # Constant function sampled on the variable of self. A plain
            # array is passed so that self's Index object is not shared
            # (and renamed) by the temporary structure.
            structure_to_add = ComplexData(
                variable=self.data.index.to_numpy(),
                function=structure_to_add * np.ones(len(self.data.index)))

        if method == 'extrapolate':
            warnings.warn("The 'extrapolate' method is not implemented yet. "
                          "Returning the initial structure.", UserWarning)
            return self

        # Express both structures on the merged set of variable points; the
        # points missing from one structure are NaN until interpolated.
        initial_data = self._extend_index(self.data,
                                          structure_to_add.data.index)
        data_to_add = self._extend_index(structure_to_add.data,
                                         self.data.index)

        # Bounds of the variable range covered by both structures.
        max_variable = min(structure_to_add.data.index.max(),
                           self.data.index.max())
        min_variable = max(structure_to_add.data.index.min(),
                           self.data.index.min())

        def in_common_range(frame):
            return ((frame.index <= max_variable)
                    & (frame.index >= min_variable))

        if method == 'common':
            initial_data = initial_data.interpolate(method=interp_kind)
            initial_data = initial_data[in_common_range(initial_data)]

            data_to_add = data_to_add.interpolate(method=interp_kind)
            data_to_add = data_to_add[in_common_range(data_to_add)]
        else:  # method == 'zero'
            # Interpolate only inside the common range; every remaining NaN
            # (outside of it) is then treated as a zero contribution.
            mask = in_common_range(initial_data)
            if mask.any():
                initial_data.loc[mask] = (
                    initial_data.loc[mask].interpolate(method=interp_kind))

            mask = in_common_range(data_to_add)
            if mask.any():
                data_to_add.loc[mask] = (
                    data_to_add.loc[mask].interpolate(method=interp_kind))

            initial_data = initial_data.fillna(0)
            data_to_add = data_to_add.fillna(0)

        result_structure = ComplexData()
        result_structure.data = initial_data + data_to_add
        return result_structure

    def __radd__(self, structure_to_add):
        return self.add(structure_to_add, method='zero')

    def __add__(self, structure_to_add):
        return self.add(structure_to_add, method='zero')

    def multiply(self, factor):
        """
        Multiply a data structure with a real number: an int or a float,
        including numpy integer and floating scalars.
        If the multiplication is done with something else, throw a warning
        and return the structure unchanged.
        """
        # Complex factors are not supported: scaling the 'real' and 'imag'
        # columns separately by a complex number would not give the product
        # of the complex function.
        if isinstance(factor, (int, float, np.integer, np.floating)):
            result_structure = ComplexData()
            result_structure.data = self.data * factor
            return result_structure
        else:
            warnings.warn(('The multiplication factor is not a float '
                           'or an int.'), UserWarning)
            return self

    def __mul__(self, factor):
        return self.multiply(factor)

    def __rmul__(self, factor):
        return self.multiply(factor)

    def __call__(self, values, interp_kind="cubic"):
        """
        Interpolation of the data by calling the class to have a function-like
        behaviour.

        Parameters
        ----------
        values : list or numpy array of complex, int or float
            values to be interpolated.
        interp_kind : str, optional
            interpolation method which is passed to scipy.interpolate.interp1d.

        Returns
        -------
        numpy array
            Contains the interpolated data.
        """
        real_func = interp1d(x=self.data.index,
                             y=self.data["real"], kind=interp_kind)
        imag_func = interp1d(x=self.data.index,
                             y=self.data["imag"], kind=interp_kind)
        return real_func(values) + 1j*imag_func(values)
    
class WakeFunction(ComplexData):
    """
    Define a WakeFunction object based on a ComplexData object.

    Parameters
    ----------
    variable : list, numpy array
        points at which the wake function is sampled.
    function : list or numpy array
        values taken by the wake function at each point.
    wake_type : str, optional
        name of the wake function component (default 'long'), stored in the
        ``wake_type`` attribute.
    """

    def __init__(self,
                 variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0]), wake_type='long'):
        super().__init__(variable, function)
        # Keep the component name so that wake components can be told apart
        # (the argument used to be silently discarded).
        self.wake_type = wake_type
    
class Impedance(ComplexData):
    """
    Define an Impedance object based on a ComplexData object.

    Parameters
    ----------
    variable : list, numpy array
        frequency values, used as the index of the data (index named
        'frequency [Hz]').
    function : list or numpy array of complex numbers
        impedance values at each frequency.
    impedance_type : str, optional
        name of the impedance component; must be one of the columns of
        impedance_name_and_coefficients_table ('long', 'xdip', ...).
    """

    def __init__(self,
                 variable=np.array([-1e15, 1e15]),
                 function=np.array([0, 0]), impedance_type='long'):
        super().__init__(variable, function)
        self._impedance_type = impedance_type
        self.data.index.name = 'frequency [Hz]'
        self.initialize_impedance_coefficient()

    def initialize_impedance_coefficient(self):
        """
        Define the impedance coefficients (a, b, c, d) and the plane of the
        impedance, which are stored as attributes of the object.

        Raises
        ------
        ValueError
            If impedance_type is not a known impedance component.
        """
        table = self.impedance_name_and_coefficients_table()

        try:
            component_coefficients = table[self.impedance_type].to_dict()
        except KeyError as err:
            # Fail explicitly: without coefficients the object is unusable.
            raise ValueError('Impedance type {} does not exist'.format(
                self.impedance_type)) from err

        self.a = component_coefficients["a"]
        self.b = component_coefficients["b"]
        self.c = component_coefficients["c"]
        self.d = component_coefficients["d"]
        self.plane = component_coefficients["plane"]

    def impedance_name_and_coefficients_table(self):
        """
        Return a table associating the human readable names of an impedance
        component and its associated coefficients and plane.
        Columns are the component names, rows are 'a', 'b', 'c', 'd' and
        'plane'.
        """

        component_dict = {
            'long': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'z'},
            'xcst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
            'ycst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
            'xdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
            'ydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'y'},
            'xydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'x'},
            'yxdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
            'xquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'x'},
            'yquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'y'},
            'xyquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'x'},
            'yxquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'y'},
            }

        return pd.DataFrame(component_dict)

    def add(self, structure_to_add, beta_x=1, beta_y=1, method='zero',
            interp_kind='cubic'):
        """
        Method to add two Impedance objects, or an Impedance and a number.
        The two Impedance objects are compared so that the addition is
        self-consistent: if their impedance_type differs, a UserWarning is
        emitted and a copy of self is returned.
        Beta functions can be specified as well: an added Impedance is
        weighted by beta_x**power_x * beta_y**power_y.

        Parameters
        ----------
        structure_to_add : Impedance, int, float or complex
            structure to be added.
        beta_x, beta_y : float, optional
            beta functions used to weight structure_to_add.
        method : str, optional
            addition method passed to ComplexData.add.
        interp_kind : str, optional
            interpolation method passed to ComplexData.add.

        Returns
        -------
        Impedance
            Impedance of the same impedance_type as self.

        Raises
        ------
        TypeError
            If structure_to_add is neither a number nor an Impedance.
        """
        if isinstance(structure_to_add, (int, float, complex)):
            result = super().add(structure_to_add, method=method,
                                 interp_kind=interp_kind)
        elif isinstance(structure_to_add, Impedance):
            if self.impedance_type == structure_to_add.impedance_type:
                weight = (beta_x ** self.power_x) * (beta_y ** self.power_y)
                result = super().add(structure_to_add * weight, method=method,
                                     interp_kind=interp_kind)
            else:
                warnings.warn(('The two Impedance objects do not have the '
                               'same coordinates or plane or type. '
                               'Returning initial Impedance object.'),
                              UserWarning)
                result = self
        else:
            raise TypeError("Cannot add an Impedance and an object of type "
                            "{}.".format(type(structure_to_add).__name__))

        impedance_to_return = Impedance(
                                result.data.index,
                                result.data.real.values + 1j*result.data.imag.values,
                                self.impedance_type)
        return impedance_to_return

    def __radd__(self, structure_to_add):
        return self.add(structure_to_add)

    def __add__(self, structure_to_add):
        return self.add(structure_to_add)

    def multiply(self, factor):
        """
        Multiply an Impedance object with a float or an int.
        If the multiplication is done with something else, throw a warning.
        """
        result = super().multiply(factor)
        impedance_to_return = Impedance(
                                result.data.index,
                                result.data.real.values + 1j*result.data.imag.values,
                                self.impedance_type)
        return impedance_to_return

    def __mul__(self, factor):
        return self.multiply(factor)

    def __rmul__(self, factor):
        return self.multiply(factor)

    @property
    def impedance_type(self):
        return self._impedance_type

    @impedance_type.setter
    def impedance_type(self, value):
        previous_type = self._impedance_type
        self._impedance_type = value
        try:
            self.initialize_impedance_coefficient()
        except ValueError:
            # Keep the object consistent with its coefficients.
            self._impedance_type = previous_type
            raise

    @property
    def power_x(self):
        power_x = self.a/2 + self.c/2.
        if self.plane == 'x':
            power_x += 1./2.
        return power_x

    @property
    def power_y(self):
        power_y = self.b/2. + self.d/2.
        if self.plane == 'y':
            power_y += 1./2.
        return power_y
    
class Wakefield:
    """
    Define a general Wakefield machine element.
    It contains Impedance or WakeFunction objects which define the wakefield,
    stored as attributes named "Z" + impedance_type or "W" + wake_type, and
    other information: the beta functions.

    Parameters
    ----------
    betax, betay : float, optional
        beta functions used to weight the impedance components when two
        Wakefield objects are added.
    structure_list : list, optional
        Impedance and/or WakeFunction objects to add to the element.
    """

    def __init__(self, betax=1, betay=1, structure_list=None):
        self.betax = betax
        self.betay = betay
        if structure_list is None:
            structure_list = []
        self.list_to_attr(structure_list)

    def append_to_model(self, structure_to_add):
        """
        Store an Impedance (as "Z" + impedance_type) or a WakeFunction (as
        "W" + wake_type) as an attribute of the element.

        Raises
        ------
        ValueError
            If a component with the same attribute name already exists, or
            if structure_to_add is neither an Impedance nor a WakeFunction.
        """
        if isinstance(structure_to_add, Impedance):
            attribute_name = "Z" + structure_to_add.impedance_type
        elif isinstance(structure_to_add, WakeFunction):
            attribute_name = "W" + structure_to_add.wake_type
        else:
            raise ValueError("{} is not an Impedance nor a WakeFunction.".format(structure_to_add))

        if hasattr(self, attribute_name):
            raise ValueError("There is already a component of the type "
                             "{} in this element.".format(attribute_name))
        setattr(self, attribute_name, structure_to_add)

    def list_to_attr(self, structure_list):
        """Append every structure of structure_list to the element."""
        for component in structure_list:
            self.append_to_model(component)

    @property
    def list_impedance_components(self):
        """
        Return the list of impedance components for the element,
        based on the attributes names (names starting with "Z").
        """
        return np.array([name for name in dir(self) if name.startswith('Z')])

    def add(self, element_to_add):
        """
        This function adds two Wakefield objects.
        The different impedance components are weighted by the beta functions
        of their element and added. The result is a Wakefield object with all
        the impedance components and beta functions equal to 1.
        Only impedance ("Z") components are summed; wake function ("W")
        components are not carried over to the result.

        Raises
        ------
        TypeError
            If element_to_add is not a Wakefield.
        """
        if not isinstance(element_to_add, Wakefield):
            raise TypeError("You can only add Wakefield objects to other "
                            "Wakefield objects.")

        list_of_summed_impedances = []
        list_of_components_added = []

        betax1, betay1 = self.betax, self.betay
        betax2, betay2 = element_to_add.betax, element_to_add.betay

        # We first go through element1 impedance components and add them
        # to the element2 components. If the component is missing in element2,
        # we then simply weight the element1 component and add it to the model.
        for component_name in self.list_impedance_components:
            element1_impedance = getattr(self, component_name)
            element1_weight = ((betax1 ** element1_impedance.power_x)
                               * (betay1 ** element1_impedance.power_y))
            element12_impedance_sum = element1_weight * element1_impedance

            # Only the lookup is allowed to fail: an error raised while
            # summing must propagate rather than silently dropping element2's
            # contribution.
            element2_impedance = getattr(element_to_add, component_name, None)
            if element2_impedance is not None:
                list_of_components_added.append(component_name)
                element12_impedance_sum = element12_impedance_sum.add(
                    element2_impedance, betax2, betay2)

            list_of_summed_impedances.append(element12_impedance_sum)

        # We now go through the components which are unique to element2 and
        # add them to the impedance model, with the beta function weighting.
        missing_components_list = [
            name for name in element_to_add.list_impedance_components
            if name not in list_of_components_added]

        for component_name in missing_components_list:
            element2_impedance = getattr(element_to_add, component_name)
            element2_weight = ((betax2 ** element2_impedance.power_x)
                               * (betay2 ** element2_impedance.power_y))
            list_of_summed_impedances.append(
                element2_weight * element2_impedance)

        # Gather everything in a Wakefield object.
        return Wakefield(betax=1., betay=1.,
                         structure_list=list_of_summed_impedances)

    def __radd__(self, structure_to_add):
        return self.add(structure_to_add)

    def __add__(self, structure_to_add):
        return self.add(structure_to_add)