# -*- coding: utf-8 -*-
"""
Module where the synchrotron class is defined.

@author: Alexis Gamelin
@date: 15/01/2020
"""

import numpy as np
from scipy.constants import c, e
        
class Synchrotron:
    """
    Synchrotron class to store main properties.
    
    Optional parameters are optional only if the Optics object passed to the 
    class uses a loaded lattice. When the Optics object uses local values, 
    L and E0 must be given as they are needed to compute the derived 
    attributes.
    
    Parameters
    ----------
    h : int
        Harmonic number of the accelerator.
    optics : Optics object
        Object where the optic functions are stored.
    particle : Particle object
        Particle which is accelerated.
    tau : array of shape (3,)
        Horizontal, vertical and longitudinal damping times in [s].
    sigma_delta : float
        Equilibrium energy spread.
    sigma_0 : float
        Natural bunch length in [s].
    emit : array of shape (2,)
        Horizontal and vertical equilibrium emittance in [m.rad].
    L : float, optional
        Ring circumference in [m].
    E0 : float, optional
        Nominal (total) energy of the ring in [eV].
    ac : float, optional
        Momentum compaction factor.
    tune : array of shape (2,), optional
        Horizontal and vertical tunes.
    chro : array of shape (2,), optional
        Horizontal and vertical (non-normalized) chromaticities.
    U0 : float, optional
        Energy loss per turn in [eV].
        
    Attributes
    ----------
    T0 : float
        Revolution time in [s], computed as L/c.
    f0 : float
        Revolution frequency in [Hz].
    omega0 : float
        Angular revolution frequency in [rad/s].
    f1 : float
        Fundamental RF frequency in [Hz].
    omega1 : float
        Fundamental RF angular frequency in [rad/s].
    k1 : float
        Fundamental RF wave number in [m**-1].
    gamma : float
        Relativistic Lorentz gamma.
    beta : float
        Relativistic Lorentz beta.
    eta : float
        Momentum slip factor, ac - 1/gamma**2.
    sigma : array of shape (4,)
        RMS beam sizes and divergences at equilibrium, ordered as 
        (x, x', y, y').
        
    Notes
    -----
    L, T0, f0, omega0, f1, omega1 and k1 are linked: setting any of them 
    updates all the others. The same holds for gamma, beta and E0.
    """
    def __init__(self, h, optics, particle, **kwargs):
        # h and particle must exist before L and E0 are assigned, as the 
        # corresponding setters read them.
        self._h = h
        self.particle = particle
        self.optics = optics
        
        if not self.optics.use_local_values:
            # A lattice is loaded: keyword arguments override lattice values.
            self.L = kwargs.get('L', self.optics.lattice.circumference)
            self.E0 = kwargs.get('E0', self.optics.lattice.energy)
            self.ac = kwargs.get('ac', self.optics.ac)
            self.tune = kwargs.get('tune', self.optics.tune)
            self.chro = kwargs.get('chro', self.optics.chro)
            self.U0 = kwargs.get('U0', self.optics.lattice.energy_loss)
        else:
            # L and E0 feed the derived attributes through their setters, so
            # fail early with an explicit message instead of an obscure 
            # "NoneType" arithmetic error.
            for name in ('L', 'E0'):
                if kwargs.get(name) is None:
                    raise TypeError(
                        "Synchrotron requires the keyword argument '{}' when "
                        "the Optics object uses local values.".format(name))
            self.L = kwargs.get('L') # Ring circumference [m]
            self.E0 = kwargs.get('E0') # Nominal (total) energy of the ring [eV]
            self.ac = kwargs.get('ac') # Momentum compaction factor
            self.tune = kwargs.get('tune') # X/Y tunes
            self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
            self.U0 = kwargs.get('U0') # Energy loss per turn [eV]
            
        self.tau = kwargs.get('tau') # X/Y/S damping times [s]
        self.sigma_delta = kwargs.get('sigma_delta') # Equilibrium energy spread
        self.sigma_0 = kwargs.get('sigma_0') # Natural bunch length [s]
        self.emit = kwargs.get('emit') # X/Y emittances in [m.rad]
                
    @property
    def h(self):
        """Harmonic number"""
        return self._h
    
    @h.setter
    def h(self, value):
        self._h = value
        # Re-assign L so that the RF quantities (f1, omega1, k1), which 
        # depend on h, are recomputed by the L setter.
        self.L = self.L
        
    @property
    def L(self):
        """Ring circumference [m]"""
        return self._L
    
    @L.setter
    def L(self, value):
        # L is the reference quantity: every revolution/RF attribute is 
        # derived from it here.
        self._L = value
        self._T0 = self.L/c
        self._f0 = 1/self.T0
        self._omega0 = 2*np.pi*self.f0
        self._f1 = self.h*self.f0
        self._omega1 = 2*np.pi*self.f1
        self._k1 = self.omega1/c
        
    @property
    def T0(self):
        """Revolution time [s]"""
        return self._T0
    
    @T0.setter
    def T0(self, value):
        self.L = c*value
        
    @property
    def f0(self):
        """Revolution frequency [Hz]"""
        return self._f0
    
    @f0.setter
    def f0(self, value):
        self.L = c/value
        
    @property
    def omega0(self):
        """Angular revolution frequency [rad/s]"""
        return self._omega0
    
    @omega0.setter
    def omega0(self, value):
        self.L = 2*np.pi*c/value
        
    @property
    def f1(self):
        """Fundamental RF frequency [Hz]"""
        return self._f1
    
    @f1.setter
    def f1(self, value):
        self.L = self.h*c/value
        
    @property
    def omega1(self):
        """Fundamental RF angular frequency [rad/s]"""
        return self._omega1
    
    @omega1.setter
    def omega1(self, value):
        self.L = 2*np.pi*self.h*c/value
        
    @property
    def k1(self):
        """Fundamental RF wave number [m**-1]"""
        return self._k1
    
    @k1.setter
    def k1(self, value):
        self.L = 2*np.pi*self.h/value
    
    @property
    def gamma(self):
        """Relativistic gamma"""
        return self._gamma

    @gamma.setter
    def gamma(self, value):
        # gamma is the reference quantity for beta and E0.
        self._gamma = value
        self._beta = np.sqrt(1 - self.gamma**-2)
        # particle.mass*c**2 is an energy in [J]; dividing by e gives [eV].
        self._E0 = self.gamma*self.particle.mass*c**2/e

    @property
    def beta(self):
        """Relativistic beta"""
        return self._beta

    @beta.setter
    def beta(self, value):
        self.gamma = 1/np.sqrt(1-value**2)
        
    @property
    def E0(self):
        """Nominal (total) energy of the ring [eV]"""
        return self._E0
    
    @E0.setter
    def E0(self, value):
        self.gamma = value/(self.particle.mass*c**2/e)

    @property
    def eta(self):
        """Momentum slip factor, ac - 1/gamma**2"""
        return self.ac - 1/(self.gamma**2)
    
    @property
    def sigma(self):
        """
        RMS beam sizes and divergences at equilibrium.
        
        Each entry is sqrt(emit*twiss + D**2*sigma_delta**2), where twiss is 
        the Twiss beta for the sizes and the Twiss gamma, 
        (1 + alpha**2)/beta, for the divergences.
        
        Returns
        -------
        sigma : array of shape (4,)
            RMS values ordered as (x, x', y, y').
        """
        emit = self.emit
        twiss_beta = self.optics.local_beta
        twiss_alpha = self.optics.local_alpha
        dispersion = self.optics.local_dispersion
        
        # The dispersive contribution scales with the variance of the 
        # relative energy deviation, i.e. sigma_delta squared.
        var_delta = self.sigma_delta**2
        
        # Twiss gamma computed from alpha and beta; unlike alpha it is always 
        # positive, so the square roots below stay real.
        twiss_gamma_x = (1 + twiss_alpha[0]**2)/twiss_beta[0]
        twiss_gamma_y = (1 + twiss_alpha[1]**2)/twiss_beta[1]
        
        sigma = np.zeros((4,))
        sigma[0] = (emit[0]*twiss_beta[0] + dispersion[0]**2*var_delta)**0.5
        sigma[1] = (emit[0]*twiss_gamma_x + dispersion[1]**2*var_delta)**0.5
        sigma[2] = (emit[1]*twiss_beta[1] + dispersion[2]**2*var_delta)**0.5
        sigma[3] = (emit[1]*twiss_gamma_y + dispersion[3]**2*var_delta)**0.5
        return sigma