# -*- coding: utf-8 -*-
"""
This module defines the most basic elements for tracking, including Element,
an abstract base class to be used as the parent class of every element
included in the tracking.

@author: gamelina
@date: 11/03/2020
"""

import numpy as np
import pandas as pd
from abc import ABCMeta, abstractmethod
from functools import wraps
from tracking.particles import Beam
from scipy import signal
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

class Element(metaclass=ABCMeta):
    """
    Abstract base class for every object acting on particles during the
    tracking. Concrete subclasses must provide their own track method.
    """

    @abstractmethod
    def track(self, beam):
        """
        Track a beam object through this Element.

        Every Element subclass has to override this method.

        Parameters
        ----------
        beam : Beam object
        """
        raise NotImplementedError

    @staticmethod
    def parallel(track):
        """
        Decorator for track methods that involve no bunch-to-bunch
        interaction (the embarrassingly parallel case).

        Decorating a track method with @Element.parallel lets it be written
        for a single Bunch object while still accepting a Beam object: for a
        Beam, the wrapped method is called once per bunch to track.

        Parameters
        ----------
        track : function, method of an Element subclass
            track method of an Element subclass which takes a Bunch object as
            input

        Returns
        -------
        track_wrapper: function, method of an Element subclass
            track method of an Element subclass which takes a Beam object or a
            Bunch object as input
        """
        @wraps(track)
        def track_wrapper(*args, **kwargs):
            target = args[1]
            element = args[0]
            extra_args = args[2:]

            if not isinstance(target, Beam):
                # Not a Beam: forward the object (a Bunch) as-is.
                track(element, target, *extra_args, **kwargs)
                return None

            if target.mpi_switch == True:
                # MPI mode: only the bunch selected by beam.mpi.bunch_num is
                # tracked (presumably the one owned by this process).
                track(element, target[target.mpi.bunch_num], *extra_args,
                      **kwargs)
                return None

            # Serial mode: track every bunch listed in beam.not_empty.
            for single_bunch in target.not_empty:
                track(element, single_bunch, *extra_args, **kwargs)
            return None

        return track_wrapper
        
class LongitudinalMap(Element):
    """
    Single-turn longitudinal map of the synchrotron.

    Parameters
    ----------
    ring : Synchrotron object
    """

    def __init__(self, ring):
        self.ring = ring

    @Element.parallel
    def track(self, bunch):
        """
        Tracking method for the element.
        There is no bunch to bunch interaction, so the method is written for
        Bunch objects and @Element.parallel takes care of Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        ring = self.ring

        # Subtract U0 / E0 from delta (assumes U0 is the energy lost per turn
        # and E0 the nominal energy, in the same units — TODO confirm).
        relative_loss = ring.U0 / ring.E0
        bunch["delta"] -= relative_loss

        # tau shift proportional to the *updated* delta, with coefficient
        # ring.ac * ring.T0 (ac presumably the momentum compaction factor).
        slip_factor = ring.ac * ring.T0
        bunch["tau"] -= slip_factor * bunch["delta"]
        
class WakePotential(Element):
    """
    Resonator model based wake potential calculation for one turn.

    Parameters
    ----------
    ring : Synchrotron object.
    Q_factor : float, optional
        Resonator quality factor, must be strictly positive. The default
        value is 1.
    f_res : float, optional
        Resonator resonance frequency in [Hz]. The default value is 10e9 Hz.
    R_shunt : float, optional
        Resonator shunt impedance in [Ohm]. The default value is 100 Ohm.
    n_bin : int, optional
        Number of bins for constructing the longitudinal bunch profile.
        The default is 65.

    Attributes
    ----------
    rho : array of shape (n_bin, )
        Bunch charge density profile.
    tau : array of shape (n_bin + time_extra, )
        Time array starting from the head of the bunch until the wake tail
        called timestop.

        The length of time_extra is determined by the last position of the
        bunch time_bunch[-1], timestop, and the mean bin width of the bunch
        profile mean_bin_size as
            len(time_extra) = (timestop - time_bunch[-1]) / mean_bin_size
    W_long : array of shape (n_bin + time_extra, )
        Wake function profile.
    W_p : array of shape (n_bin + time_extra, )
        Wake potential profile.
    wp : array of shape (mp_number, )
        Wake potential exerted on each macro-particle.

    Methods
    -------
    charge_density(bunch, n_bin=None)
        Calculate bunch charge density.
    plot(var, plot_rho=True)
        Plot the wake function or the wake potential.
    track(bunch)
        Tracking method for the element.

    Raises
    ------
    ValueError
        If Q_factor is not strictly positive.
    """

    def __init__(self, ring, Q_factor=1, f_res=10e9, R_shunt=100, n_bin=65):
        # Q <= 0 would divide by zero (Q = 0) or give a meaningless
        # resonator (Q < 0) in every formula below.
        if Q_factor <= 0:
            raise ValueError("Q_factor must be strictly positive.")
        self.ring = ring
        self.n_bin = n_bin

        self.Q_factor = Q_factor
        self.omega_res = 2*np.pi*f_res
        self.R_shunt = R_shunt

        # Q >= 0.5: under-damped (oscillating) regime, Q_factor_p is used.
        # Q < 0.5: over-damped regime, Q_factor_pp is used and omega_res_p
        # enters hyperbolic instead of trigonometric functions.
        if Q_factor >= 0.5:
            self.Q_factor_p = np.sqrt(self.Q_factor**2 - 0.25)
            self.omega_res_p = (self.omega_res*self.Q_factor_p)/self.Q_factor
        else:
            self.Q_factor_pp = np.sqrt(0.25 - self.Q_factor**2)
            self.omega_res_p = (self.omega_res*self.Q_factor_pp)/self.Q_factor

    def charge_density(self, bunch, n_bin=None):
        """
        Bin the bunch and compute its normalised charge density self.rho.

        Parameters
        ----------
        bunch : Bunch object
        n_bin : int, optional
            Number of bins. If None (default), self.n_bin is used.
        """
        if n_bin is None:
            n_bin = self.n_bin
        # Assumes bins[0] is an interval index exposing .mid / .length and
        # bins[2] the macro-particle count per bin — TODO confirm against
        # Bunch.binning.
        self.bins = bunch.binning(n_bin=n_bin)
        self.bin_data = self.bins[2]
        self.bin_size = self.bins[0].length

        # Charge in each bin divided by bin width and total bunch charge.
        self.rho = bunch.charge_per_mp*self.bin_data/ \
            (self.bin_size*bunch.charge)

    def init_timestop(self):
        """
        Initial end time of the wake: the time at which the envelope
        exp(-omega_res*t/(2*Q_factor)) has dropped by a factor 1000, rounded
        to 12 decimal places.
        """
        self.timestop = round(np.log(1000)/self.omega_res*2*self.Q_factor, 12)

    def time_array(self):
        """
        Build self.tau: the bin centres of the bunch followed by extra
        samples, spaced by the mean bin width, up to self.timestop.
        """
        time_bunch = self.bins[0].mid
        mean_bin_size = np.mean(self.bin_size)
        time_extra = np.arange(start = time_bunch[-1]+mean_bin_size,
                               stop = self.timestop, step = mean_bin_size)

        self.tau = np.concatenate((time_bunch,time_extra))

    def long_wakefunction(self):
        """
        Evaluate the resonator longitudinal wake function on self.tau and
        store it in self.W_long.

        The wake is zero for t < 0. For t >= 0 it is
            -(omega_res*R_shunt/Q) * exp(-omega_res*t/(2*Q)) * osc(t)
        with, writing w = omega_res_p,
            osc(t) = cos(w*t) - sin(w*t)/(2*Q_factor_p)     for Q > 0.5,
            osc(t) = 1 - omega_res*t/(2*Q)                 for Q = 0.5,
            osc(t) = cosh(w*t) - sinh(w*t)/(2*Q_factor_pp) for Q < 0.5.
        """
        tau = np.asarray(self.tau, dtype=float)
        w_long = np.zeros_like(tau)
        # Only samples with t >= 0 receive a non-zero wake.
        causal = tau >= 0
        t = tau[causal]
        phase = self.omega_res_p*t

        if self.Q_factor >= 0.5:
            if self.Q_factor_p == 0:
                # Critically damped resonator: the general expression would
                # evaluate sin(0)/0 = nan, so use its analytic limit.
                osc = 1 - self.omega_res*t/(2*self.Q_factor)
            else:
                osc = np.cos(phase) - np.sin(phase)/(2*self.Q_factor_p)
        else:
            osc = np.cosh(phase) - np.sinh(phase)/(2*self.Q_factor_pp)

        w_long[causal] = -(self.omega_res*self.R_shunt/self.Q_factor)*\
            np.exp(-self.omega_res*t/(2*self.Q_factor))*osc
        self.W_long = w_long

    def wake_potential(self):
        """Convolve the wake function with the charge density into self.W_p."""
        # NOTE(review): the 1e-12 factor presumably stands in for the time
        # step of the discrete convolution (1 ps) — confirm it matches the
        # actual bin width.
        self.W_p = signal.convolve(self.W_long*1e-12, self.rho, mode="same")

    def plot(self, var, plot_rho=True):
        """
        Plot the wake function or the wake potential.

        Parameters
        ----------
        var : {'W_p', 'W_long' }
            If 'W_p', the wake potential is plotted.
            If 'W_long', the wake function is plotted.
        plot_rho : bool, optional
            Overlay the bunch charge density profile on the plot, rescaled
            so that its peak matches the largest absolute value of the
            plotted profile. The default is True.

        Returns
        -------
        fig : Figure
            Figure object with the plot on it.

        Raises
        ------
        ValueError
            If var is neither 'W_p' nor 'W_long'.
        """
        # Validate before creating the figure so no empty figure is leaked.
        if var == "W_p":
            profile = self.W_p
            ylabel = "W$_p$ (V/pC)"
        elif var == "W_long":
            profile = self.W_long
            ylabel = "W$_{||}$ ($\\Omega$/ps)"
        else:
            raise ValueError("var must be 'W_p' or 'W_long'.")

        fig, ax = plt.subplots()
        ax.plot(self.tau*1e12, profile*1e-12)
        ax.set_xlabel("$\\tau$ (ps)")
        ax.set_ylabel(ylabel)

        if plot_rho:
            rho_array = np.asarray(self.rho, dtype=float)
            # Scale rho to the profile actually being plotted.
            rho_rescaled = rho_array/np.max(rho_array)*np.max(np.abs(profile))
            ax.plot(self.bins[0].mid*1e12, rho_rescaled*1e-12)

        return fig

    def _wake_tail_too_large(self):
        """
        Return True if any of the last five samples of self.W_long is larger
        in magnitude than 1/1000 of the wake's peak magnitude.
        """
        w_long = np.asarray(self.W_long, dtype=float)
        if w_long.size == 0:
            return False
        peak = np.max(np.abs(w_long))
        # Multiplication instead of peak/tail avoids divide-by-zero warnings
        # on tail samples that are exactly 0.
        return bool(np.any(np.abs(w_long[-5:])*1000 > peak))

    def check_wake_tail(self):
        """
        Checking whether the full wake function is obtained by the
        calculated initial timestop.

        While the tail of the wake function is not negligible, timestop is
        extended by 50 ps and self.tau and self.W_long are recomputed.
        """
        while self._wake_tail_too_large():
            self.timestop += 50e-12
            self.time_array()
            self.long_wakefunction()

    @Element.parallel
    def track(self, bunch):
        """
        Tracking method for the element.
        No bunch to bunch interaction, so written for Bunch objects and
        @Element.parallel is used to handle Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object.
        """
        self.charge_density(bunch, n_bin = self.n_bin)
        self.init_timestop()
        self.time_array()
        self.long_wakefunction()
        self.check_wake_tail()
        self.wake_potential()

        # Outside the computed time window the wake potential is taken as 0.
        f = interp1d(self.tau, self.W_p, fill_value = 0, bounds_error = False)
        self.wp = f(bunch["tau"])

        bunch["delta"] += self.wp * bunch.charge / self.ring.E0
        
class SynchrotronRadiation(Element):
    """
    Element to handle synchrotron radiation, radiation damping and quantum
    excitation, for a single turn in the synchrotron.

    Parameters
    ----------
    ring : Synchrotron object
    switch : bool array of shape (3,), optional
        Allows to choose on which plane the synchrotron radiation is active:
        index 0 for the longitudinal plane (delta), index 1 for the
        horizontal plane (x, xp) and index 2 for the vertical plane (y, yp).
        If None (the default), all three planes are active.
    """

    def __init__(self, ring, switch=None):
        self.ring = ring
        # Create a new array per instance. An array used directly as the
        # default value would be built once and shared by every instance,
        # so mutating one instance's switch would change all the others.
        if switch is None:
            switch = np.ones((3,), dtype=bool)
        self.switch = switch

    @Element.parallel
    def track(self, bunch):
        """
        Tracking method for the element.
        No bunch to bunch interaction, so written for Bunch objects and
        @Element.parallel is used to handle Beam objects.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        ring = self.ring

        if self.switch[0]:
            # Longitudinal: damping of delta with time constant ring.tau[2]
            # plus a Gaussian kick scaled by ring.sigma_delta.
            rand = np.random.normal(size=len(bunch))
            bunch["delta"] = ((1 - 2*ring.T0/ring.tau[2])*bunch["delta"] +
                 2*ring.sigma_delta*(ring.T0/ring.tau[2])**0.5*rand)

        if self.switch[1]:
            # Horizontal: Gaussian kicks on x and xp (time constant
            # ring.tau[0]); xp is rescaled by the energy change accumulated
            # in bunch.energy_change.
            rand = np.random.normal(size=(len(bunch),2))
            kick_x = (2*ring.T0/ring.tau[0])**0.5
            bunch["x"] += ring.sigma[0]*kick_x*rand[:,0]
            bunch["xp"] = (1 + bunch["delta"])/(1 + bunch["delta"] + bunch.energy_change)*bunch["xp"]
            bunch["xp"] += ring.sigma[1]*kick_x*rand[:,1]

        if self.switch[2]:
            # Vertical: same treatment as horizontal, with ring.tau[1].
            rand = np.random.normal(size=(len(bunch),2))
            kick_y = (2*ring.T0/ring.tau[1])**0.5
            bunch["y"] += ring.sigma[2]*kick_y*rand[:,0]
            bunch["yp"] = (1 + bunch["delta"])/(1 + bunch["delta"] + bunch.energy_change)*bunch["yp"]
            bunch["yp"] += ring.sigma[3]*kick_y*rand[:,1]

        # Reset energy change to 0 for next turn
        bunch.energy_change = 0
        
class TransverseMap(Element):
    """
    Transverse map for a single turn in the synchrotron.

    Parameters
    ----------
    ring : Synchrotron object
    """

    def __init__(self, ring):
        self.ring = ring
        # Local optics functions; in track() index 0 is used for the
        # horizontal plane and index 1 for the vertical plane.
        self.alpha = self.ring.optics.local_alpha
        self.beta = self.ring.optics.local_beta
        self.gamma = self.ring.optics.local_gamma
        # Indices 0, 1 feed x, xp and indices 2, 3 feed y, yp in track().
        self.dispersion = self.ring.optics.local_dispersion
        # One-turn phase advance 2*pi*tune for the first two tunes (x, y).
        self.phase_advance = self.ring.tune[0:2]*2*np.pi

    @Element.parallel
    def track(self, bunch):
        """
        Tracking method for the element.
        No bunch to bunch interaction, so written for Bunch objects and
        @Element.parallel is used to handle Beam objects.

        In each plane (u = x or y) the linear one-turn map
            u'  = (cos + alpha*sin)*u + beta*sin*up + D*delta
            up' = -gamma*sin*u + (cos - alpha*sin)*up + D'*delta
        is applied, with a per-particle phase advance that depends on delta
        through ring.chro.

        Parameters
        ----------
        bunch : Bunch or Beam object
        """
        delta = bunch["delta"]

        # Compute phase advance which depends on energy via chromaticity
        # (tune scaled by 1 + chro*delta).
        phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*delta)
        phase_advance_y = self.phase_advance[1]*(1+self.ring.chro[1]*delta)

        # Compute each trigonometric term once and form the needed matrix
        # coefficients directly, rather than filling a mostly-zero
        # (6, 6, n_particles) array on every turn.
        cos_x = np.cos(phase_advance_x)
        sin_x = np.sin(phase_advance_x)
        cos_y = np.cos(phase_advance_y)
        sin_y = np.sin(phase_advance_y)

        x_in = bunch["x"]
        xp_in = bunch["xp"]
        y_in = bunch["y"]
        yp_in = bunch["yp"]

        # NOTE(review): the local dispersion values are used directly as the
        # delta coefficients of the one-turn map; confirm this matches the
        # convention of ring.optics.local_dispersion.
        # Horizontal
        x = ((cos_x + self.alpha[0]*sin_x)*x_in
             + self.beta[0]*sin_x*xp_in
             + self.dispersion[0]*delta)
        xp = (-1*self.gamma[0]*sin_x*x_in
              + (cos_x - self.alpha[0]*sin_x)*xp_in
              + self.dispersion[1]*delta)

        # Vertical
        y = ((cos_y + self.alpha[1]*sin_y)*y_in
             + self.beta[1]*sin_y*yp_in
             + self.dispersion[2]*delta)
        yp = (-1*self.gamma[1]*sin_y*y_in
              + (cos_y - self.alpha[1]*sin_y)*yp_in
              + self.dispersion[3]*delta)

        # All new coordinates are computed from the old ones before any is
        # written back.
        bunch["x"] = x
        bunch["xp"] = xp
        bunch["y"] = y
        bunch["yp"] = yp