# -*- coding: utf-8 -*-
"""
This module defines utility functions that help to deal with the
collective_effects module.

@author: Alexis Gamelin
@date: 28/09/2020
"""

import pandas as pd
import numpy as np
from pathlib import Path
from scipy.interpolate import interp1d
from scipy.constants import c
from mbtrack2.collective_effects.wakefield import Impedance, WakeFunction, WakeField
from pathlib import Path


def read_CST(file, impedance_type='long', divide_by=None):
    """
    Load an impedance exported by CST into an Impedance object.
    
    The file is parsed as tab-separated text without header; anything 
    following a "#" is ignored. The three columns are taken as frequency, 
    real part and imaginary part.
    
    Parameters
    ----------
    file : str
        Path to the file to read.
    impedance_type : str, optional
        Type of the Impedance object. For "long", the real part is replaced
        by its absolute value.
    divide_by : float, optional
        If given, both real and imaginary parts are divided by this value. 
        Mainly used to normalize transverse impedance by displacement.
        
    Returns
    -------
    result : Impedance object
        Data from file, indexed by frequency.
    """
    columns = ["Frequency", "Real", "Imaginary"]
    table = pd.read_csv(file, comment="#", header=None, sep="\t",
                        names=columns)
    
    # Frequency column is rescaled by 1e9 (presumably GHz in the file, Hz in
    # the Impedance object -- confirm against the CST export settings).
    table["Frequency"] = table["Frequency"]*1e9
    
    if divide_by is not None:
        for part in ("Real", "Imaginary"):
            table[part] = table[part]/divide_by
            
    if impedance_type == "long":
        table["Real"] = np.abs(table["Real"])
        
    table = table.set_index("Frequency")
    complex_impedance = table["Real"] + 1j*table["Imaginary"]
    return Impedance(variable=table.index,
                     function=complex_impedance,
                     impedance_type=impedance_type)

def read_IW2D(file, file_type='Zlong'):
    """
    Read IW2D file format into an Impedance object or a WakeFunction object.
    
    The first character of file_type selects the kind of data ("Z" for an
    impedance, "W" for a wake function); the remaining characters are passed
    on as the impedance/wake type (e.g. "long", "xdip").
    
    Files are parsed as space-separated text and the first row is skipped.
    
    Parameters
    ----------
    file : str
        Path to the file to read.
    file_type : str, optional
        Type of the Impedance or WakeFunction object, e.g. "Zlong" or 
        "Wxdip".
        
    Returns
    -------
    result : Impedance or WakeFunction object
        Data from file. Impedance data is indexed by frequency, with rows 
        holding NaN in the real or imaginary part removed. Wake data is 
        indexed by time (distance divided by the speed of light), with NaN 
        values interpolated and the interpolated rows printed.
        
    Raises
    ------
    ValueError
        If file_type does not begin with "Z" or "W" (including an empty 
        string).
    """
    # Slicing instead of indexing: an empty file_type then reaches the
    # ValueError below instead of failing with an IndexError.
    kind = file_type[:1]
    component = file_type[1:]
    
    if kind == "Z":
        df = pd.read_csv(file, sep = ' ', header = None, 
                         names = ["Frequency","Real","Imaginary"], skiprows=1)
        df.set_index("Frequency", inplace = True)
        df = df.dropna(subset=["Real", "Imaginary"])
        return Impedance(variable = df.index,
                         function = df["Real"] + 1j*df["Imaginary"],
                         impedance_type=component)
    
    if kind == "W":
        df = pd.read_csv(file, sep = ' ', header = None, 
                         names = ["Distance","Wake"], skiprows=1)
        df["Time"] = df["Distance"] / c
        df.set_index("Time", inplace = True)
        nan_mask = df.isna().values
        if nan_mask.any():
            df = df.interpolate()
            print("Nan values have been interpolated to:")
            print(df[nan_mask])
        # NOTE: the wake is used with the sign found in the file, no sign 
        # flip is applied to the longitudinal wake.
        return WakeFunction(variable = df.index,
                            function = df["Wake"],
                            wake_type=component)
    
    raise ValueError("file_type should begin by Z or W.")

def read_IW2D_folder(folder, suffix, select="WZ"):
    """
    Gather the IW2D result files of a folder into a single WakeField object.
    
    For each selected kind ("W" and/or "Z") and each of the components 
    "long", "xdip", "ydip", "xquad" and "yquad", the file named 
    kind + component + suffix is read from folder with read_IW2D.
    
    Parameters
    ----------
    folder : str
        Path to the folder holding the IW2D files.
    suffix : str
        End of the name of each files. For example, in "Zlong_test.dat" the
        suffix should be "_test.dat".
    select : str, optional
        Select which object to load. "W" for WakeFunction, "Z" for Impedance 
        and "WZ" or "ZW" for both.
        
    Returns
    -------
    result : WakeField object
        WakeField object with Impedance and WakeFunction objects from the 
        different files.
        
    Raises
    ------
    ValueError
        If select is not one of "W", "Z", "WZ" or "ZW".
    """
    if select in ("WZ", "ZW"):
        # Wake functions are always loaded before impedances.
        prefixes = ("W", "Z")
    elif select in ("W", "Z"):
        prefixes = (select,)
    else:
        raise ValueError("select should be W, Z or WZ.")
        
    root = Path(folder)
    
    loaded = [read_IW2D(file=root / (prefix + component + suffix),
                        file_type=prefix + component)
              for prefix in prefixes
              for component in ("long", "xdip", "ydip", "xquad", "yquad")]
            
    return WakeField(loaded)

def spectral_density(frequency, sigma, m = 1, mode="Hermite"):
    """
    Compute the spectral density of different modes for various values of the
    head-tail mode number, based on Table 1 p238 of [1].
    
    For "Hermite" modes the returned value is x**(2*m) * exp(-x**2) with 
    x = 2*pi*frequency*sigma.
    
    Parameters
    ----------
    frequency : list or numpy array
        sample points of the spectral density in [Hz]
    sigma : float
        RMS bunch length in [s]
    m : int, optional
        head-tail (or azimutal/synchrotron) mode number
    mode: str, optional
        type of the mode taken into account for the computation:
        -"Hermite" modes for Gaussian bunches

    Returns
    -------
    numpy array
        Spectral density sampled at the points given in frequency.
        
    Raises
    ------
    NotImplementedError
        If mode is not "Hermite".
    
    References
    ----------
    [1] : Handbook of accelerator physics and engineering, 3rd printing.
    """
    if mode != "Hermite":
        raise NotImplementedError("Not implemanted yet.")
        
    # Plain Python sequences do not support multiplication by a float, so 
    # they are converted; other inputs (arrays, scalars, ...) are left as is.
    if isinstance(frequency, (list, tuple)):
        frequency = np.asarray(frequency, dtype=float)
        
    x = 2*np.pi*frequency*sigma
    return x**(2*m)*np.exp(-1*x**2)
        
def gaussian_bunch_spectrum(frequency, sigma): 
    """
    Evaluate the spectrum of a Gaussian bunch [1], i.e.
    exp(-1/2 * (2*pi*frequency)**2 * sigma**2).

    Parameters
    ----------
    frequency : array
        sample points of the beam spectrum in [Hz].
    sigma : float
        RMS bunch length in [s].

    Returns
    -------
    bunch_spectrum : array
        Bunch spectrum sampled at points given in frequency.
        
    References
    ----------
    [1] : Gamelin, A. (2018). Collective effects in a transient microbunching 
    regime and ion cloud mitigation in ThomX. p86, Eq. 4.19
    """
    angular_frequency = 2*np.pi*frequency
    exponent = -1/2*angular_frequency**2*sigma**2
    return np.exp(exponent)

def gaussian_bunch(time, sigma): 
    """
    Evaluate a normalized Gaussian bunch profile centred on time zero.

    Parameters
    ----------
    time : array
        sample points of the bunch profile in [s].
    sigma : float
        RMS bunch length in [s].

    Returns
    -------
    bunch_profile : array
        Bunch profile in [s**-1] sampled at points given in time.
    """
    exponent = -1/2*(time**2/sigma**2)
    normalization = sigma*np.sqrt(2*np.pi)
    return np.exp(exponent)/normalization
        
        
def beam_spectrum(frequency, M, bunch_spacing, sigma=None, 
                  bunch_spectrum=None): 
    """
    Compute the beam spectrum assuming constant spacing between bunches [1].
    
    The bunch spectrum is multiplied by a phase factor and by the ratio
    sin(M*x)/sin(x) with x = pi*frequency*bunch_spacing. Where sin(x) is 
    exactly zero (e.g. at frequency = 0) the ratio is replaced by its 
    analytic limit M*cos(M*x)/cos(x) instead of evaluating 0/0.

    Parameters
    ----------
    frequency : list or numpy array
        sample points of the beam spectrum in [Hz].
    M : int
        Number of bunches.
    bunch_spacing : float
        Time between two bunches in [s].
    sigma : float, optional
        If bunch_spectrum is None then a Gaussian bunch with sigma RMS bunch 
        length in [s] is assumed.
    bunch_spectrum : array, optional
        Bunch spectrum sampled at points given in frequency.

    Returns
    -------
    beam_spectrum : array
        Complex beam spectrum sampled at points given in frequency.
        
    Raises
    ------
    TypeError
        If both sigma and bunch_spectrum are None.

    References
    ----------
    [1] Rumolo, G - Beam Instabilities - CAS - CERN Accelerator School: 
        Advanced Accelerator Physics Course - 2014, Eq. 9
    """
    # Plain Python sequences do not support multiplication by a float.
    if isinstance(frequency, (list, tuple)):
        frequency = np.asarray(frequency, dtype=float)
    
    if bunch_spectrum is None:
        if sigma is None:
            raise TypeError("beam_spectrum needs either sigma or "
                            "bunch_spectrum to be given.")
        bunch_spectrum = gaussian_bunch_spectrum(frequency, sigma)
    elif isinstance(bunch_spectrum, (list, tuple)):
        bunch_spectrum = np.asarray(bunch_spectrum)
        
    x = np.pi*frequency*bunch_spacing
    numerator = np.sin(M*np.pi*frequency*bunch_spacing)
    denominator = np.sin(x)
    
    # np.where evaluates both branches everywhere, so the warnings raised by
    # the discarded 0/0 entries are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        bunch_train_factor = np.where(denominator == 0,
                                      M*np.cos(M*x)/np.cos(x),
                                      numerator/denominator)
        
    phase_factor = np.exp(1j*np.pi*frequency*bunch_spacing*(M-1))
    
    return bunch_spectrum * phase_factor * bunch_train_factor
    
    
def effective_impedance(ring, imp, m, mu, sigma, M, tuneS, xi=None, 
                        mode="Hermite"):
    """
    Compute the effective (longitudinal or transverse) impedance. 
    Formulas from Eq. (1) and (2) p238 of [1].
    
    If the lowest frequency of imp is strictly positive, imp is first 
    extended in place to negative frequencies with double_sided_impedance.
    
    Parameters
    ----------
    ring : Synchrotron object
    imp : Impedance object
    m : int
        head-tail (or azimutal/synchrotron) mode number
    mu : int
        coupled bunch mode number, goes from 0 to (M-1) where M is the
        number of bunches
    sigma : float
        RMS bunch length in [s]
    M : int
        Number of bunches.
    tuneS : float
        Synchrotron tune.
    xi : float, optional
        (non-normalized) chromaticity. If None, the value is read from 
        ring.chro for the plane of the impedance.
    mode: str, optional
        type of the mode taken into account for the computation:
        -"Hermite" modes for Gaussian bunches

    Returns
    -------
    Zeff : float 
        effective impedance in [ohm] or in [ohm/m] depending on the impedance
        type.
        
    Raises
    ------
    TypeError
        If imp is not an Impedance object, or if its type is not "long", 
        "xdip" or "ydip".
    NotImplementedError
        If mode is not "Hermite".
        
    References
    ----------
    [1] : Handbook of accelerator physics and engineering, 3rd printing.
    """
    if not isinstance(imp, Impedance):
        raise TypeError("{} should be an Impedance object.".format(imp))
        
    # NOTE(review): the frequency span is read before the impedance is made
    # double sided and is not refreshed afterwards, so the harmonic range 
    # below is built from the original span -- confirm this is intended.
    fmin = imp.data.index.min()
    fmax = imp.data.index.max()
    if fmin > 0:
        double_sided_impedance(imp)
        
    if mode != "Hermite":
        raise NotImplementedError("Not implemanted yet.")
        
    def h(f):
        return spectral_density(frequency=f, sigma=sigma, m=m,
                                mode="Hermite")
    
    pmax = fmax/(ring.f0 * M) - 1
    pmin = fmin/(ring.f0 * M) + 1
    p = np.arange(pmin,pmax+1)
    
    imp_type = imp.impedance_type
    
    if imp_type == "long":
        fp = ring.f0*(p*M + mu + m*tuneS)
        fp = fp[np.nonzero(fp)] # Avoid division by 0
        sampled_impedance = imp(fp)
        weights = h(fp)
        num = np.sum( sampled_impedance * weights / (fp*2*np.pi) )
        den = np.sum( weights )
        return num/den
    
    if imp_type in ("xdip", "ydip"):
        # Index 0 is used for "xdip" and index 1 for "ydip".
        plane = 0 if imp_type == "xdip" else 1
        tuneXY = ring.tune[plane]
        if xi is None:
            xi = ring.chro[plane]
        fp = ring.f0*(p*M + mu + tuneXY + m*tuneS)
        f_xi = xi/ring.eta*ring.f0
        sampled_impedance = imp(fp)
        weights = h(fp - f_xi)
        num = np.sum( sampled_impedance * weights )
        den = np.sum( weights )
        return num/den
    
    raise TypeError("Effective impedance is only defined for long, xdip"
                    " and ydip impedance type.")


def yokoya_elliptic(x_radius , y_radius):
    """
    Compute Yokoya factors for an elliptic beam pipe.
    Function adapted from N. Mounet IW2D.
    
    The factors are linearly interpolated from the table stored in 
    "data/Yokoya_elliptic_from_Elias_USPAS.csv" next to this module, at the 
    ratio (large - small)/(large + small) of the two semi-axes.

    Parameters
    ----------
    x_radius : float
        Horizontal semi-axis of the ellipse in [m].
    y_radius : float
        Vertical semi-axis of the ellipse in [m].

    Returns
    -------
    yoklong : float
        Yokoya factor for the longitudinal impedance.
    yokxdip : float
        Yokoya factor for the dipolar horizontal impedance.
    yokydip : float
        Yokoya factor for the dipolar vertical impedance.
    yokxquad : float
        Yokoya factor for the quadrupolar horizontal impedance.
    yokyquad : float
        Yokoya factor for the quadrupolar vertical impedance.
    """
    vertical_is_smaller = y_radius < x_radius
    if vertical_is_smaller:
        small_semiaxis, large_semiaxis = y_radius, x_radius
    else:
        small_semiaxis, large_semiaxis = x_radius, y_radius
        
    # Interpolation table shipped with the package.
    # BEWARE: columns are ratio, dipy, dipx, quady, quadx
    table_path = (Path(__file__).parent / "data" 
                  / "Yokoya_elliptic_from_Elias_USPAS.csv")
    table = pd.read_csv(table_path)
    
    # Semi-axes ratio, tabulated in the "x" column of the file
    ratio = (large_semiaxis - small_semiaxis)/(large_semiaxis + small_semiaxis)
    
    factors = {column: np.interp(ratio, table["x"], table[column])
               for column in ("dipy", "dipx", "quady", "quadx")}
    
    yoklong = 1
    
    if vertical_is_smaller:
        # Table columns are used as labelled.
        return (yoklong, factors["dipx"], factors["dipy"], 
                factors["quadx"], factors["quady"])
    
    # Otherwise the x and y columns of the table are swapped.
    return (yoklong, factors["dipy"], factors["dipx"], 
            factors["quady"], factors["quadx"])

def beam_loss_factor(impedance, frequency, spectrum, ring):
    """
    Compute "beam" loss factor using the beam spectrum, uses a sum instead of 
    integral compared to loss_factor [1].
    
    The sum runs over multiples of ring.f0. If the impedance has no negative
    frequency point, it is first extended in place with 
    double_sided_impedance.

    Parameters
    ----------
    impedance : Impedance of type "long"
    frequency : array
        Sample points of spectrum.
    spectrum : array
        Beam spectrum to consider.
    ring : Synchrotron object

    Returns
    -------
    kloss_beam : float
        Beam loss factor in [V/C].
        
    References
    ----------
    [1] : Handbook of accelerator physics and engineering, 3rd printing. 
        Eq (3) p239.
    """
    pmax = np.floor(impedance.data.index.max()/ring.f0)
    pmin = np.floor(impedance.data.index.min()/ring.f0)
    
    if pmin >= 0:
        # Single sided impedance: mirror it and use a symmetric range.
        double_sided_impedance(impedance)
        pmin = -1*pmax
    
    # Both ends of the range are excluded from the sum.
    harmonics = np.arange(pmin+1,pmax)*ring.f0
    real_impedance = np.real(impedance(harmonics))
    
    # interpolation of the spectrum is needed to avoid problems linked to 
    # division by 0
    # computing the spectrum directly to the frequency points gives
    # wrong results
    power_spectrum = interp1d(frequency, np.abs(spectrum)**2)
    
    return ring.f0 * np.sum(real_impedance*power_spectrum(harmonics))

def double_sided_impedance(impedance):
    """
    Add negative frequency points to single sided impedance spectrum following
    symmetries depending on impedance type.
    
    The data of the impedance is mirrored to negative frequencies with:
        - "long": sign of the imaginary part flipped.
        - "xdip", "ydip", "xquad", "yquad": sign of the real part flipped.
    The mirrored point at zero frequency is discarded so that it is not 
    duplicated. The impedance object is modified in place; nothing is done 
    if its lowest frequency is already negative.

    Parameters
    ----------
    impedance : Impedance object
        Single sided impedance.
        
    Raises
    ------
    ValueError
        If the impedance type is not one of the types listed above.
    """
    fmin = impedance.data.index.min()
    
    # Written as a negated ">=" so that a NaN minimum is also left untouched.
    if not fmin >= 0:
        return
    
    negative_data = impedance.data.set_index(impedance.data.index*-1)
    
    imp_type = impedance.impedance_type
    
    if imp_type == "long":
        negative_data["imag"] = -1*negative_data["imag"]
    elif imp_type in ("xdip", "ydip", "xquad", "yquad"):
        negative_data["real"] = -1*negative_data["real"]
    else:
        raise ValueError("Wrong impedance type")
        
    # Remove the mirrored zero frequency row, if any. A mask is used since
    # the mirrored label is -0.0 for a float index.
    negative_data = negative_data[negative_data.index != 0]
    
    # DataFrame.append was removed in pandas 2.0, pd.concat is the 
    # replacement and is also available in older versions.
    all_data = pd.concat([impedance.data, negative_data])
    impedance.data = all_data.sort_index()