# -*- coding: utf-8 -*-
"""
Beam and bunch elements

@author: Alexis Gamelin
@date: 17/01/2020
"""

import numpy as np
import pandas as pd

class Bunch:
    """
    Define a bunch of macro-particles.
    
    Parameters
    ----------
    ring : Synchrotron object
        Ring in which the bunch circulates. This class reads ``ring.T0``,
        ``ring.particle.charge`` and, in init_gaussian, ``ring.emit``,
        ``ring.sigma_0``, ``ring.sigma_delta`` and ``ring.mean_optics``.
    mp_number : float, optional
        Macro-particle number (converted to int).
    current : float, optional
        Bunch current in [A]
    alive : bool, optional
        If False, the bunch is defined as empty: it holds a single dead
        macro-particle and carries zero current.
        
    Attributes
    ----------
    particles : pandas DataFrame
        Coordinates of every macro-particle (alive or not), with columns
        "x", "xp", "y", "yp", "tau" and "delta".
    alive : pandas Series of bool
        True for each macro-particle that is alive.
    mp_number : int
        Macro-particle number
    charge : float
        Bunch charge in [C]
    charge_per_mp : float
        Charge per macro-particle in [C]
    particle_number : int
        Number of particles in the bunch
    current : float
        Bunch current in [A]
        
    Methods
    -------
    init_gaussian(cov=None, mean=None, **kwargs)
        Initialize bunch particles with 6D gaussian phase space.
    """
    
    def __init__(self, ring, mp_number=1e3, current=1e-3, alive=True):
        
        self.ring = ring
        if not alive:
            # An empty bunch is represented by one dead macro-particle
            # carrying no current.
            mp_number = 1
            current = 0
        self._mp_number = int(mp_number)
        
        particles = {"x":np.zeros((self.mp_number,)),
                     "xp":np.zeros((self.mp_number,)),
                     "y":np.zeros((self.mp_number,)),
                     "yp":np.zeros((self.mp_number,)),
                     "tau":np.zeros((self.mp_number,)),
                     "delta":np.zeros((self.mp_number,)),
                     }
        self.particles = pd.DataFrame(particles)
        self.alive = pd.Series(np.ones((self.mp_number,),dtype=bool))
        # The current setter divides by the number of alive particles, so it
        # must run while all macro-particles are still marked alive.
        self.current = current
        if not alive:
            self.alive = pd.Series(np.zeros((self.mp_number,),dtype=bool))
        
    def __len__(self):
        """Return the number of alive particles"""
        return len(self[:])
        
    def __getitem__(self, label):
        """Return the columns label for alive particles"""
        return self.particles.loc[self.alive, label]
    
    def __setitem__(self, label, value):
        """Set value to the columns label for alive particles"""
        self.particles.loc[self.alive, label] = value
    
    def __iter__(self):
        """Iterate over labels"""
        return self[:].__iter__()
    
    def __repr__(self):
        """Return representation of alive particles"""
        return f'{self[:]!r}'
        
    @property
    def mp_number(self):
        """Macro-particle number"""
        return self._mp_number
    
    @mp_number.setter
    def mp_number(self, value):
        # Rebuild the bunch with the new macro-particle number while keeping
        # the bunch current. The current must be read before re-initializing,
        # and it is the current [A] (not the charge [C]) that __init__
        # expects. Re-initialization resets every coordinate to zero and
        # marks every macro-particle as alive.
        current = self.current
        self.__init__(self.ring, value, current)
        
    @property
    def charge_per_mp(self):
        """Charge per macro-particle [C]"""
        return self._charge_per_mp
    
    @charge_per_mp.setter
    def charge_per_mp(self, value):
        self._charge_per_mp = value
        
    @property
    def charge(self):
        """Bunch charge in [C], carried by the alive macro-particles only"""
        return self.__len__()*self.charge_per_mp
    
    @charge.setter
    def charge(self, value):
        # Spread the charge over the alive macro-particles.
        self.charge_per_mp = value / self.__len__()
    
    @property
    def particle_number(self):
        """
        Particle number, rounded to the nearest integer so that a
        set/get round trip is not lost to floating-point truncation.
        """
        return int(np.rint(self.charge / np.abs(self.ring.particle.charge)))
    
    @particle_number.setter
    def particle_number(self, value):
        # Use the magnitude of the particle charge, consistent with the
        # getter, so that the bunch charge and current stay positive.
        self.charge_per_mp = (value * np.abs(self.ring.particle.charge) 
                              / self.__len__())
        
    @property
    def current(self):
        """Bunch current [A]"""
        return self.charge / self.ring.T0
    
    @current.setter
    def current(self, value):
        self.charge_per_mp = value * self.ring.T0 / self.__len__()
    
    @property    
    def mean(self):
        """
        Compute the mean position of alive particles for each 
        coordinates.
        
        Returns
        -------
        mean : numpy array of shape (6, 1)
            mean position of alive particles, one row per coordinate in
            column order ("x", "xp", "y", "yp", "tau", "delta")
        """
        mean = [[self[name].mean()] for name in self]
        return np.array(mean)
    
    @property
    def std(self):
        """
        Compute the standard deviation of the position of alive 
        particles for each coordinates.
        
        Returns
        -------
        std : numpy array of shape (6, 1)
            standard deviation of the position of alive particles, one row
            per coordinate in column order
        """
        std = [[self[name].std()] for name in self]
        return np.array(std)
    
    @property    
    def emit(self):
        """
        Compute the bunch emittance for each plane [1].
        
        The second moments are taken about zero, not about the bunch
        centroid, so a non-zero mean offset increases the result.
        NOTE(review): the longitudinal value applies the same formula to
        ("tau", "delta"); whether this is the intended longitudinal
        emittance definition is still to be confirmed.
        
        Returns
        -------
        emit : numpy array
            bunch emittance in the order [x, y, longitudinal]
        
        References
        ----------
        [1] Wiedemann, H. (2015). Particle accelerator physics. 4th 
        edition. Springer, Eq.(8.39) of p224.
        """
        emitX = (np.mean(self['x']**2)*np.mean(self['xp']**2) - 
                 np.mean(self['x']*self['xp'])**2)**(0.5)
        emitY = (np.mean(self['y']**2)*np.mean(self['yp']**2) - 
                 np.mean(self['y']*self['yp'])**2)**(0.5)
        emitS = (np.mean(self['tau']**2)*np.mean(self['delta']**2) - 
                 np.mean(self['tau']*self['delta'])**2)**(0.5)
        return np.array([emitX, emitY, emitS])
        
    def init_gaussian(self, cov=None, mean=None, **kwargs):
        """
        Initialize bunch particles with 6D gaussian phase space.
        Covariance matrix is taken from [1].
        
        Coordinates are drawn for all mp_number macro-particles, alive or
        not, using numpy's global random state.
        
        Parameters
        ----------
        cov : (6,6) array, optional
            Covariance matrix of the bunch distribution
        mean : (6,) array, optional
            Mean of the bunch distribution, zero by default.
        **kwargs
            Only used when cov is None:
            sigma_0 : float, optional
                RMS of the "tau" coordinate, default ring.sigma_0.
            sigma_delta : float, optional
                RMS of the "delta" coordinate, default ring.sigma_delta.
            optics : object with beta, alpha and gamma attributes, optional
                Optical functions of the transverse planes (index 0 for x,
                index 1 for y), default ring.mean_optics.
        
        References
        ----------
        [1] Wiedemann, H. (2015). Particle accelerator physics. 4th 
        edition. Springer, Eq.(8.38) of p223.

        """
        if mean is None:
            mean = np.zeros((6,))
        
        if cov is None:
            sigma_0 = kwargs.get("sigma_0", self.ring.sigma_0)
            sigma_delta = kwargs.get("sigma_delta", self.ring.sigma_delta)
            optics = kwargs.get("optics", self.ring.mean_optics)
            
            cov = np.zeros((6,6))
            cov[0,0] = self.ring.emit[0]*optics.beta[0]
            cov[1,1] = self.ring.emit[0]*optics.gamma[0]
            cov[0,1] = -1*self.ring.emit[0]*optics.alpha[0]
            cov[1,0] = -1*self.ring.emit[0]*optics.alpha[0]
            cov[2,2] = self.ring.emit[1]*optics.beta[1]
            cov[3,3] = self.ring.emit[1]*optics.gamma[1]
            cov[2,3] = -1*self.ring.emit[1]*optics.alpha[1]
            cov[3,2] = -1*self.ring.emit[1]*optics.alpha[1]
            cov[4,4] = sigma_0**2
            cov[5,5] = sigma_delta**2
            
        values = np.random.multivariate_normal(mean, cov, size=self.mp_number)
        self.particles["x"] = values[:,0]
        self.particles["xp"] = values[:,1]
        self.particles["y"] = values[:,2]
        self.particles["yp"] = values[:,3]
        self.particles["tau"] = values[:,4]
        self.particles["delta"] = values[:,5]
        
class Beam:
    """
    Define a Beam object composed of several Bunch objects. 
    
    Parameters
    ----------
    ring : Synchrotron object
        Ring in which the beam circulates; ring.h sets the number of bunch
        slots.
    bunch_list : list of Bunch object, optional
        One Bunch per slot, so its length must equal ring.h. If None, the
        beam is initialized with empty bunches only.

    Attributes
    ----------
    bunch_list : list of Bunch object
        Bunches of the beam, empty ones included.
    current : float
        Total beam current in [A]
    charge : float
        Total beam charge in [C]
    particle_number : int
        Total number of particles in the beam
    filling_pattern : array of bool
        Filling pattern of the beam
        
    Methods
    -------
    init_beam(filling_pattern, current_per_bunch=1e-3, mp_per_bunch=1e3)
        Initialize beam with a given filling pattern and macro-particle number 
        per bunch. Then initialize the different bunches with a 6D gaussian
        phase space.
    """
    
    def __init__(self, ring, bunch_list=None):
        self.ring = ring
        
        if bunch_list is None:
            # 1-D pattern of length h with every slot empty.
            self.init_beam(np.zeros((self.ring.h,), dtype=bool))
        else:
            if (len(bunch_list) != self.ring.h):
                raise ValueError(("The length of the bunch list is {} ".format(len(bunch_list)) + 
                                  "but should be {}".format(self.ring.h)))
            self.bunch_list = bunch_list
            
    def __len__(self):
        """Return the number of (not empty) bunches"""
        return sum(1 for _ in self.not_empty)
    
    def __getitem__(self, i):
        """Return the bunch number i"""
        return self.bunch_list.__getitem__(i)
    
    def __setitem__(self, i, value):
        """Set value to the bunch number i"""
        self.bunch_list.__setitem__(i, value)
    
    def __iter__(self):
        """Iterate over all bunches"""
        return self.bunch_list.__iter__()
   
    @property             
    def not_empty(self):
        """Return a generator to iterate over not empty bunches (non-zero
        current)"""
        for bunch in self:
            if bunch.current != 0:
                yield bunch
        
    def init_beam(self, filling_pattern, current_per_bunch=1e-3, 
                  mp_per_bunch=1e3):
        """
        Initialize beam with a given filling pattern and macro-particle number 
        per bunch. Then initialize the different bunches with a 6D gaussian
        phase space.
        
        If the filling pattern is an array of bool then the current per bunch 
        is uniform, else the filling pattern can be an array with the current
        in each bunch.
        
        Parameters
        ----------
        filling_pattern : numpy array or list of length ring.h
            Filling pattern of the beam, can be a list or an array of bool, 
            then current_per_bunch is used. Or can be an array of real numbers
            (float or int) with the current in each bunch in [A]; a zero
            current gives an empty bunch.
        current_per_bunch : float, optional
            Current per bunch in [A]
        mp_per_bunch : float, optional
            Macro-particle number per bunch
            
        Raises
        ------
        ValueError
            If the length of filling_pattern differs from ring.h.
        TypeError
            If filling_pattern is neither bool nor a real numeric type.
        """
        
        if (len(filling_pattern) != self.ring.h):
            raise ValueError(("The length of filling pattern is {} ".format(len(filling_pattern)) + 
                              "but should be {}".format(self.ring.h)))
        
        filling_pattern = np.asarray(filling_pattern)
        dtype = filling_pattern.dtype
        if dtype == np.dtype("bool"):
            bunch_list = [Bunch(self.ring, mp_per_bunch, current_per_bunch)
                          if filled else Bunch(self.ring, alive=False)
                          for filled in filling_pattern]
        elif (np.issubdtype(dtype, np.floating) 
              or np.issubdtype(dtype, np.integer)):
            bunch_list = [Bunch(self.ring, mp_per_bunch, float(current))
                          if current != 0 else Bunch(self.ring, alive=False)
                          for current in filling_pattern]
        else:
            raise TypeError("{} should be bool or a real numeric type "
                            "(e.g. float64)".format(dtype))
                
        self.bunch_list = bunch_list
        
        for bunch in self.not_empty:
            bunch.init_gaussian()
    
    @property
    def filling_pattern(self):
        """Return an array with the filling pattern of the beam as bool:
        True for each bunch with a non-zero current"""
        return np.array([bunch.current != 0 for bunch in self], dtype=bool)
        
    @property
    def bunch_current(self):
        """Return an array with the current in each bunch in [A]"""
        bunch_current = [bunch.current for bunch in self]
        return np.array(bunch_current)
    
    @property
    def bunch_charge(self):
        """Return an array with the charge in each bunch in [C]"""
        bunch_charge = [bunch.charge for bunch in self]
        return np.array(bunch_charge)
    
    @property
    def bunch_particle(self):
        """Return an array with the particle number in each bunch"""
        bunch_particle = [bunch.particle_number for bunch in self]
        return np.array(bunch_particle)
    
    @property
    def current(self):
        """Total beam current in [A]"""
        return np.sum(self.bunch_current)
    
    @property
    def charge(self):
        """Total beam charge in [C]"""
        return np.sum(self.bunch_charge)
    
    @property
    def particle_number(self):
        """Total number of particles in the beam"""
        return np.sum(self.bunch_particle)
    
    @property
    def bunch_mean(self):
        """Return an array of shape (6, ring.h) with the mean position of
        alive particles for each bunches"""
        bunch_mean = np.zeros((6,self.ring.h))
        for index, bunch in enumerate(self):
            bunch_mean[:,index] = np.squeeze(bunch.mean)
        return bunch_mean
    
    @property
    def bunch_std(self):
        """Return an array of shape (6, ring.h) with the standard deviation of
        the position of alive particles for each bunches"""
        bunch_std = np.zeros((6,self.ring.h))
        for index, bunch in enumerate(self):
            bunch_std[:,index] = np.squeeze(bunch.std)
        return bunch_std
    
    @property
    def bunch_emit(self):
        """Return an array of shape (3, ring.h) with the bunch emittance of
        alive particles for each bunches and each plane"""
        bunch_emit = np.zeros((3,self.ring.h))
        for index, bunch in enumerate(self):
            bunch_emit[:,index] = np.squeeze(bunch.emit)
        return bunch_emit