# -*- coding: utf-8 -*-
"""
Beam loading module
Created on Fri Aug 23 13:32:03 2019

@author: gamelina
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.path as mpltPath
from mpl_toolkits.mplot3d import Axes3D
import sys
from mpi4py import MPI
from scipy.optimize import root, fsolve
from scipy.constants import c
from scipy.integrate import solve_ivp, quad, romb
from scipy.interpolate import interp1d, griddata
from scipy import real, imag

class BeamLoadingVlasov():
    """Class used to compute beam equilibrium profile and stability for a given
    storage ring and a list of RF cavities of any harmonic. The class assumes
    a uniform filling of the storage ring. Based on an extension of [1].

    [1] Venturini, M. (2018). Passive higher-harmonic rf cavities with general
    settings and multibunch instabilities in electron storage rings.
    Physical Review Accelerators and Beams, 21(11), 114404.

    Parameters
    ----------
    ring : Synchrotron object
    cavity_list : list of CavityResonator objects
    I0 : beam current in [A].
    auto_set_MC_theta : if True, allow class to change cavity phase for
        CavityResonator objects with m = 1 (i.e. main cavities)
    F : list of form factor amplitude
    PHI : list of form factor phase
    B1 : lower integration boundary
    B2 : upper integration boundary
    """

    def __init__(
                self, ring, cavity_list, I0, auto_set_MC_theta=False, F=None,
                PHI=None, B1=-0.2, B2=0.2):
        self.ring = ring
        self.cavity_list = cavity_list
        self.I0 = I0
        self.n_cavity = len(cavity_list)
        self.auto_set_MC_theta = auto_set_MC_theta
        if F is None:
            self.F = np.ones((self.n_cavity,))
        else:
            self.F = F
        if PHI is None:
            self.PHI = np.zeros((self.n_cavity,))
        else:
            self.PHI = PHI
        self.B1 = B1
        self.B2 = B2
        self.mpi = False
        self.__version__ = "1.0"

        # Define constants for scaled potential u(z)
        self.u0 = self.ring.U0 / (
            self.ring.ac * self.ring.sigma_delta**2
            * self.ring.E0 * self.ring.L)
        self.ug = np.zeros((self.n_cavity,))
        self.ub = np.zeros((self.n_cavity,))
        self.update_potentials()
            
    def update_potentials(self):
        """Update potentials with cavity and ring data."""
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            self.ug[i] = cavity.Vg / (
                self.ring.ac * self.ring.sigma_delta ** 2 *
                self.ring.E0 * self.ring.L * self.ring.k1 *
                cavity.m)
            self.ub[i] = 2 * self.I0 * cavity.Rs / (
                self.ring.ac * self.ring.sigma_delta**2 *
                self.ring.E0 * self.ring.L * self.ring.k1 *
                cavity.m * (1 + cavity.beta))
        
    def energy_balance(self):
        """Return energy balance for the synchronous particle
        (z = 0 ,delta = 0)."""
        delta = self.ring.U0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            delta += cavity.Vb(self.I0) * self.F[i] * np.cos(cavity.psi - self.PHI[i])
            delta -= cavity.Vg * np.cos(cavity.theta_g)
        return delta
    
    def center_of_mass(self):
        """Return center of mass position in [s]"""
        z0 = np.linspace(self.B1, self.B2, 1000)
        rho = self.rho(z0)
        CM = np.average(z0, weights=rho)
        return CM/c

    def u(self, z):
        """Scaled potential u(z)"""
        pot = self.u0 * z
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            pot += - self.ug[i] * (
                np.sin(self.ring.k1 * cavity.m * z + cavity.theta_g)
                - np.sin(cavity.theta_g))
            pot += self.ub[i] * self.F[i] * np.cos(cavity.psi) * (
                np.sin(self.ring.k1 * cavity.m * z
                       + cavity.psi - self.PHI[i])
                - np.sin(cavity.psi - self.PHI[i]))
        return pot
    
    def du_dz(self, z):
        """Partial derivative of the scaled potential u(z) by z"""
        pot = self.u0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            pot += - self.ug[i] * self.ring.k1 * cavity.m * np.cos(self.ring.k1 * cavity.m * z + cavity.theta_g)
            pot += self.ub[i] * self.F[i] * self.ring.k1 * cavity.m * np.cos(cavity.psi) * np.cos(self.ring.k1 * cavity.m * z + cavity.psi - self.PHI[i])
        return pot

    def uexp(self, z):
        return np.exp(-1 * self.u(z))

    def integrate_func(self, f, g):
        """Return Integral[f*g]/Integral[f] between B1 and B2"""
        A = quad(lambda x: f(x) * g(x), self.B1, self.B2)
        B = quad(f, self.B1, self.B2)
        return A[0] / B[0]

    def to_solve(self, x, CM=True):
        """System of non-linear equation to solve to find the form factors F
        and PHI at equilibrum.
        The system is composed of Eq. (B6) and (B7) of [1] for each cavity.
        If auto_set_MC_theta == True, the system also find the main cavity 
        phase to impose energy balance or cancel center of mass offset.
        If CM is True, the system imposes zero center of mass offset,
        if False, the system imposes energy balance.
        """
        # Update values of F, PHI and theta
        if self.auto_set_MC_theta:
            self.F = x[:-1:2]
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                if cavity.m == 1:
                    cavity.theta = x[-1]
                    cavity.set_generator(0.5)
                    self.update_potentials()
        else:
            self.F = x[::2]
        self.PHI = x[1::2]

        # Compute system
        if self.auto_set_MC_theta:
            res = np.zeros((self.n_cavity * 2 + 1,))
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
                    lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
                res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
                    lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
            # Factor 1e-8 or 1e12 for better convergence
            if CM is True:
                res[self.n_cavity * 2] = self.center_of_mass() * 1e12
            else:
                res[self.n_cavity * 2] = self.energy_balance() * 1e-8
        else:
            res = np.zeros((self.n_cavity * 2,))
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
                    lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
                res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
                    lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
        return res

    def rho(self, z):
        """Return bunch equilibrium profile at postion z"""
        A = quad(lambda y: self.uexp(y), self.B1, self.B2)
        return self.uexp(z) / A[0]

    def plot_rho(self, z1=None, z2=None):
        """Plot the bunch equilibrium profile between z1 and z2.

        Parameters
        ----------
        z1 : lower bound of the plot, defaults to B1
        z2 : upper bound of the plot, defaults to B2
        """
        lower = self.B1 if z1 is None else z1
        upper = self.B2 if z2 is None else z2
        z0 = np.linspace(lower, upper, 1000)
        plt.plot(z0, self.rho(z0))
        plt.xlabel("z [m]")
        plt.title("Equilibrium bunch profile")
        
    def voltage(self, z):
        """Return the RF system total voltage at position z"""
        Vtot = 0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            Vtot += cavity.VRF(z, self.I0, self.F[i], self.PHI[i])
        return Vtot
    
    def dV(self, z):
        """Return derivative of the RF system total voltage at position z"""
        Vtot = 0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            Vtot += cavity.dVRF(z, self.I0, self.F[i], self.PHI[i])
        return Vtot
    
    def ddV(self, z):
        """Return the second derivative of the RF system total voltage at position z"""
        Vtot = 0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            Vtot += cavity.ddVRF(z, self.I0, self.F[i], self.PHI[i])
        return Vtot
    
    def deltaVRF(self, z):
        """Return the generator voltage minus beam loading voltage of the total RF system at position z"""
        Vtot = 0
        for i in range(self.n_cavity):
            cavity = self.cavity_list[i]
            Vtot += cavity.deltaVRF(z, self.I0, self.F[i], self.PHI[i])
        return Vtot
    
    def plot_dV(self, z1=None, z2=None):
        """Plot the derivative of RF system total voltage between z1 and z2.

        Parameters
        ----------
        z1 : lower bound of the plot, defaults to B1
        z2 : upper bound of the plot, defaults to B2
        """
        if z1 is None:
            z1 = self.B1
        if z2 is None:
            z2 = self.B2
        z0 = np.linspace(z1, z2, 1000)
        plt.plot(z0, self.dV(z0))
        plt.xlabel("z [m]")
        # The curve is the derivative, not the voltage itself; no unit is
        # given because it depends on the convention used by cavity.dVRF.
        plt.ylabel("Derivative of total RF voltage")
        
    def plot_voltage(self, z1=None, z2=None):
        """Plot the RF system total voltage between z1 and z2.

        Parameters
        ----------
        z1 : lower bound of the plot, defaults to B1
        z2 : upper bound of the plot, defaults to B2
        """
        lower = self.B1 if z1 is None else z1
        upper = self.B2 if z2 is None else z2
        z0 = np.linspace(lower, upper, 1000)
        plt.plot(z0, self.voltage(z0))
        plt.xlabel("z [m]")
        plt.ylabel("Total RF voltage (V)")

    def std_rho(self, z1=None, z2=None):
        """Return the rms bunch equilibrium size in [m]"""
        if z1 is None:
            z1 = self.B1
        if z2 is None:
            z2 = self.B2
        z0 = np.linspace(z1, z2, 1000)
        values = self.rho(z0)
        average = np.average(z0, weights=values)
        variance = np.average((z0 - average)**2, weights=values)
        return np.sqrt(variance)

    def beam_equilibrium(self, x0=None, tol=1e-4, method='hybr', options=None, 
                         plot = False, CM=True):
        """Solve system of non-linear equation to find the form factors F
        and PHI at equilibrum. Can be used to compute the equilibrium bunch
        profile.
        
        Parameters
        ----------
        x0 : initial guess
        tol : tolerance for termination of the algorithm
        method : method used by scipy.optimize.root to solve the system
        options : options given to scipy.optimize.root
        plot : if True, plot the equilibrium bunch profile
        CM : if True, the system imposes zero center of mass offset,
        if False, the system imposes energy balance.
        
        Returns
        -------
        sol : OptimizeResult object representing the solution
        """
        if x0 is None:
            x0 = [1, 0] * self.n_cavity
            if self.auto_set_MC_theta:
                x0 = x0 + [self.cavity_list[0].theta]

        if CM:
            print("The initial center of mass offset is " +
                  str(self.center_of_mass()*1e12) + " ps")
        else:
            print("The initial energy balance is " +
                  str(self.energy_balance()) + " eV")

        sol = root(lambda x : self.to_solve(x, CM), x0, tol=tol, method=method, 
                   options=options)

        # Update values of F, PHI and theta_g
        if self.auto_set_MC_theta:
            self.F = sol.x[:-1:2]
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                if cavity.m == 1:
                    cavity.theta = sol.x[-1]
        else:
            self.F = sol.x[::2]
        self.PHI = sol.x[1::2]

        if CM:
            print("The final center of mass offset is " +
                  str(self.center_of_mass()*1e12) + " ps")
        else:
            print("The final energy balance is " +
                  str(self.energy_balance()) + " eV")
        print("The algorithm has converged: " + str(sol.success))
        
        if plot:
            self.plot_rho(self.B1 / 4, self.B2 / 4)

        return sol

    def H0(self, z):
        """Unperturbed Hamiltonian for delta = 0"""
        return self.u(z)*self.ring.ac*c*self.ring.sigma_delta**2

    def H0_val(self, z, value):
        return self.H0(value) - self.H0(z)
    
    def Jfunc(self, z, zri_val):
        """Convenience function to compute the J integral"""
        return np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
                       - 2*self.u(z)*self.ring.sigma_delta**2)

    def Tsfunc(self, z, zri_val):
        """Convenience function to compute the Ts integral"""
        return 1/np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
                       - 2*self.u(z)*self.ring.sigma_delta**2)
        
    def dz_dt(self, delta):
        """dz/dt part of the equation of motion for synchrotron oscillation"""
        return self.ring.ac*c*delta
    
    def ddelta_dt(self, z):
        """ddelta/dt part of the equation of motion for synchrotron oscillation"""
        VRF = 0
        for i in range(self.n_cavity):
            cav = self.cavity_list[i]
            VRF += cav.VRF(z, self.I0, self.F[i],self.PHI[i])
        return (VRF - self.ring.U0)/(self.ring.E0*self.ring.T0)
            
    def EOMsystem(self, t, y):
        """System of equations of motion for synchrotron oscillation
        
        Parameters
        ----------
        t : time
        y[0] : z
        y[1] : delta
        """
        
        res = np.zeros(2,)
        res[0] = self.dz_dt(y[1])
        res[1] = self.ddelta_dt(y[0])
        return res
    
    def cannonical_transform(self, zmax = 4.5e-2, tmax = 2e-3, nz = 1000, 
                             ntime = 1000, plot = False, epsilon = 1e-10,
                             rtol = 1e-9, atol = 1e-13, epsrel = 1e-9,
                             epsabs = 1e-11, zri0_coef = 1e-2, n_pow = 16,
                             eps = 1e-9):
        """Compute cannonical transform zeta, see [1] appendix C.

        The results are stored on the instance: J, phi, z, delta, Omegas,
        Ts (grids), J_p, phi_p, z_p (flattened grids used by zeta), func_J
        and func_omegas (interpolation functions) and path (polygon
        delimiting the (J, phi) region covered by the map).

        Parameters
        ----------
        zmax : maximum amplitude of the z-grid for the cannonical transform
            computation
        nz : number of point in the z-grid
        tmax : maximum amplitude of the time-grid
        ntime : number of point in the time-grid
        plot : if True, plot the surface map of the cannonical transform
        epsilon : convergence criterion for the maximum value of uexp(zmax) 
            and uexp(zmin)
        rtol, atol : relative and absolute tolerances for the equation of 
            motion solving
        epsabs, epsrel : absolute and relative tolerances for the integration
            of J and Ts
        zri0_coef : lower boundary of the z-grid is given by dz * zri0_coef, to
            be lowered if lower values of z/J are needed in the grid
        n_pow : power of two used in the romb integration of J/Ts if the error
            on J/Ts is bigger than 0.1 % or if the value is nan
        eps : if the romb integration is used, the integration interval is
            limited to [zli[i] + eps,zri[i] - eps]

        Raises
        ------
        Exception : if no zmin is found or if uexp is not below epsilon at
            zmax and zmin
        """

        def integrate_orbit(integrand, lower, upper, label, index):
            # Integrate with quad; fall back to romb on a regular grid when
            # the quad error estimate exceeds 0.1 % or the result is nan.
            value, error = quad(integrand, lower, upper, epsabs = epsabs,
                                epsrel = epsrel, limit = int(1e4))
            # numpy scalars: a zero value gives inf/nan instead of raising
            value = np.float64(value)
            relative_error = np.float64(error)/value
            use_romb = False
            if relative_error > 0.001:
                print("Error is bigger than 0.1% for i = " + str(index)
                      + ", " + label + " = " + str(value) +
                      ", relative error on " + label + " = "
                      + str(relative_error))
                use_romb = True
            elif np.isnan(value):
                print(label + " is nan for i = " + str(index) )
                use_romb = True
            if use_romb:
                x0 = np.linspace(lower + eps, upper - eps, 2**n_pow+1)
                value = romb(integrand(x0), x0[1]-x0[0])
                print("Using romb to compute the integral instead of quad : "
                      + label + " = " + str(value))
            return value

        # step (i): find zmin such that H0(zmin) = H0(zmax)

        sol = root(lambda z : self.H0_val(z,zmax), -zmax)

        if not sol.success:
            raise Exception("Error, no solution for zmin was found : " + sol.message)

        # root returns a size-1 array; keep a scalar so that dz is a scalar
        zmin = sol.x[0]

        if self.uexp(zmax) > epsilon or self.uexp(zmin) > epsilon:
            raise Exception("Error, uexp(zmax) > epsilon or uexp(zmin) > epsilon, try to increase zmax")

        # step (ii): right turning points zri and matching left ones zli

        # NOTE(review): with zmin close to -zmax this grid extends to about
        # 2*zmax rather than zmax -- confirm this is intended.
        dz = (zmax-zmin)/nz
        zri = np.arange(nz)*dz
        zri[0] = dz*zri0_coef
        zli = np.zeros((nz,))

        for i in range(nz):
            zr = zri[i]
            left, infodict, ier, mesg = fsolve(
                lambda z : self.H0_val(z,zr), -zr, full_output = True)
            zli[i] = left[0]
            if ier != 1:
                print("Error at i = " + str(i) + " : " + mesg)

        # step (iii): action J and synchrotron period Ts of each orbit

        J = np.zeros((nz,))
        Ts = np.zeros((nz,))

        for i in range(nz):
            zr = zri[i]
            J[i] = integrate_orbit(
                lambda z : 1/np.pi*self.Jfunc(z,zr), zli[i], zr, "J", i)
            Ts[i] = integrate_orbit(
                lambda z : 2*self.Tsfunc(z,zr)/(self.ring.ac*c),
                zli[i], zr, "Ts", i)

        Omegas = 2*np.pi/Ts
        zri[0] = 0 # approximation to dz*zri0_coef

        # step (iv): track each orbit and sample it on a common time grid

        z_sol = np.zeros((nz,ntime))
        delta_sol = np.zeros((nz,ntime))
        phi_sol = np.zeros((nz,ntime))

        tspan = (0, tmax)
        time_base = np.linspace(tspan[0],tspan[1],ntime)

        for i in range(nz):
            y0 = (zri[i], 0)
            sol = solve_ivp(self.EOMsystem, tspan, y0, rtol = rtol, atol = atol)

            fz = interp1d(sol.t,sol.y[0],kind='cubic') # cubic decreases K std but increases a bit K mean 
            fdelta = interp1d(sol.t,sol.y[1],kind='cubic')

            z_sol[i,:] = fz(time_base)
            delta_sol[i,:] = fdelta(time_base)
            phi_sol[i,:] = Omegas[i]*time_base

        # Plot the surface

        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            surf = ax.plot_surface(J, phi_sol.T, z_sol.T, rcount = 75,
                                   ccount = 75, cmap =plt.cm.viridis,
                                   linewidth=0, antialiased=True)
            fig.colorbar(surf, shrink=0.5, aspect=5)
            ax.set_xlabel('J [m]')
            ax.set_ylabel(r'$\phi$ [rad]')
            ax.set_zlabel('z [m]')

        # Check the results by computing the Jacobian matrix with centered
        # finite differences on the interior of the (J, phi) grid

        dz_dp = np.zeros((nz,ntime))
        dz_dJ = np.zeros((nz,ntime))
        dd_dp = np.zeros((nz,ntime))
        dd_dJ = np.zeros((nz,ntime))

        dphi = phi_sol[1:-1,2:] - phi_sol[1:-1,:-2]
        dJ = (J[2:] - J[:-2])[:,np.newaxis]

        degenerate = np.flatnonzero(dJ == 0) + 1
        if degenerate.size > 0:
            print("J[i+1] - J[i-1] is zero for i = " + str(degenerate.tolist()))

        dz_dp[1:-1,1:-1] = (z_sol[1:-1,2:] - z_sol[1:-1,:-2])/dphi
        dz_dJ[1:-1,1:-1] = (z_sol[2:,1:-1] - z_sol[:-2,1:-1])/dJ
        dd_dp[1:-1,1:-1] = (delta_sol[1:-1,2:] - delta_sol[1:-1,:-2])/dphi
        dd_dJ[1:-1,1:-1] = (delta_sol[2:,1:-1] - delta_sol[:-2,1:-1])/dJ

        K = np.abs(dz_dp*dd_dJ - dd_dp*dz_dJ)
        # Borders have no centered difference: copy the neighbouring values
        K[0,:] = K[1,:]
        K[-1,:] = K[-2,:]
        K[:,0] = K[:,1]
        K[:,-1] = K[:,-2]

        print('Numerical Jacobian K: mean = ' + str(np.mean(K)) + ', value should be around 1.')
        print('Numerical Jacobian K: std = ' + str(np.std(K)) + ', value should be around 0.')
        print('Numerical Jacobian K: max = ' + str(np.max(K)) + ', value should be around 1.')
        print('If the values are far from their nominal values, decrease atol, rtol, epsrel and epsabs.')

        # Plot the numerical Jacobian

        if plot:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
            surf = ax.plot_surface(J, phi_sol.T, K.T, rcount = 75, ccount = 75,
                                   cmap =plt.cm.viridis,
                                   linewidth=0, antialiased=True)
            fig.colorbar(surf, shrink=0.5, aspect=5)
            ax.set_xlabel('J [m]')
            ax.set_ylabel(r'$\phi$ [rad]')
            ax.set_zlabel('K')

        self.J = J
        self.phi = phi_sol.T
        self.z = z_sol.T
        self.delta = delta_sol.T
        self.Omegas = Omegas
        self.Ts = Ts

        # Flattened (J, phi, z) samples for the scattered interpolation in
        # zeta. J is repeated along the time axis so that the three arrays
        # have the same (ntime, nz) layout, also when nz != ntime.
        JJ = np.tile(J, (ntime, 1))
        self.J_p = np.ravel(JJ)
        self.phi_p = np.ravel(phi_sol.T)
        self.z_p = np.ravel(z_sol.T)

        # Provide interpolation functions for methods
        self.func_J = interp1d(z_sol[:,0],J, bounds_error=True)
        self.func_omegas = interp1d(J,Omegas, bounds_error=True)

        # Provide polygon to check if points are inside the map
        p1 = [[J[i],phi_sol[i,:].max()] for i in range(J.size)]
        p2 = [[J[i],phi_sol[i,:].min()] for i in range(J.size-1,0,-1)]
        poly = p1+p2
        self.path = mpltPath.Path(poly)
        
    def dphi0(self, z, delta):
        """Partial derivative of the distribution function phi by z"""
        dphi0 = -np.exp(-delta**2/(2*self.ring.sigma_delta**2)) / (np.sqrt(2*np.pi)*self.ring.sigma_delta) * self.rho(z)*self.du_dz(z)
        return dphi0
            
    def zeta(self, J, phi):
        """Compute zeta cannonical transformation z = zeta(J,phi) using 2d interpolation
        interp2d and bisplrep does not seem to work for this case
        griddata is working fine if method='linear'
        griddata return nan if the point is outside the grid
        """
        z = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J, phi), method='linear')
        if np.isnan(np.sum(z)) == True:
            inside = self.path.contains_points(np.array([J, phi]).T)
            if (np.sum(inside != True) >= 1):
                print("Some points are outside the grid of the cannonical transformation zeta !")
                print(str(np.sum(inside != True)) + " points are outside the grid out of " + str(J.size) + " points.")
                plot_points_outside = False
                if(plot_points_outside):
                    fig = plt.figure()
                    ax = fig.add_subplot(111)
                    patch = patches.PathPatch(self.path, facecolor='orange', lw=2)
                    ax.add_patch(patch)
                    ax.set_xlim(self.J_p.min()*0.5,self.J_p.max()*1.5)
                    ax.set_ylim(self.phi_p.min()*0.5,self.phi_p.max()*1.5)
                    ax.scatter(J[np.nonzero(np.invert(inside))[0]],phi[np.nonzero(np.invert(inside))[0]])
                    plt.show()
                    vals = np.nonzero(np.invert(inside))[0]
                    print("J min = " + str(J[vals].min()))
                    print("J max = " + str(J[vals].max()))
                    print("Phi min = " + str(phi[vals].min()))
                    print("Phi max = " + str(phi[vals].max()))
            else:
                print("Some values interpolated from the zeta cannonical transformation are nan !")
                print("This may be due to a scipy version < 1.3")
            print("The nan values are remplaced by nearest values in the grid")
            arr = np.isnan(z)
            z[arr] = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J[arr], phi[arr]), method='nearest')
        if np.isnan(np.sum(z)) == True:
            print("Correction by nearest values in the grid failed, program is now stopping")
            raise ValueError("zeta cannonical transformation are nan")
        return z
            
    def H_coef(self, J_val, m, p, l, n_pow = 14):
        """Compute H_{m,p} (J_val) using zeta cannonical transformation"""
        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
        omegap = self.ring.omega0*(p*self.ring.h + l)
        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
        res = np.zeros(J_val.shape, dtype=complex)
        phi0_array = np.tile(phi0, J_val.size)
        J_array = np.array([])
        for i in range(J_val.size):
            J_val_array = np.tile(J_val[i], phi0.size)
            J_array = np.concatenate((J_array, J_val_array))
        zeta_values = self.zeta(J_array, phi0_array)
        y0 = np.exp(1j*m*phi0_array + 1j*omegap*zeta_values/c)
        for i in range(J_val.size):
            res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
        return res
    
    def H_coef_star(self, J_val, m, p, l, n_pow = 14):
        """Compute H_{m,p}^* (J_val) using zeta cannonical transformation"""
        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
        omegap = self.ring.omega0*(p*self.ring.h + l)
        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
        res = np.zeros(J_val.shape, dtype=complex)
        phi0_array = np.tile(phi0, J_val.size)
        J_array = np.array([])
        for i in range(J_val.size):
            J_val_array = np.tile(J_val[i], phi0.size)
            J_array = np.concatenate((J_array, J_val_array))
        zeta_values = self.zeta(J_array, phi0_array)
        y0 = np.exp( - 1j*m*phi0_array - 1j*omegap*zeta_values/c)
        for i in range(J_val.size):
            res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
        return res
    
    def HH_coef(self, J_val, m, p, pp, l, n_pow = 10):
        """Compute H_{m,p} (J_val) * H_{m,pp}^* (J_val) using zeta cannonical transformation"""
        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
        omegapH = self.ring.omega0*(p*self.ring.h + l)
        omegapH_star = self.ring.omega0*(pp*self.ring.h + l)
        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
        H = np.zeros(J_val.shape, dtype=complex)
        H_star = np.zeros(J_val.shape, dtype=complex)
        phi0_array = np.tile(phi0, J_val.size)
        J_array = np.array([])
        for i in range(J_val.size):
            J_val_array = np.tile(J_val[i], phi0.size)
            J_array = np.concatenate((J_array, J_val_array))
        zeta_values = self.zeta(J_array, phi0_array)
        y0 = np.exp( 1j*m*phi0_array + 1j*omegapH*zeta_values/c)
        y1 = np.exp( - 1j*m*phi0_array - 1j*omegapH_star*zeta_values/c)
        for i in range(J_val.size):
            H[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
            H_star[i] = romb(y1[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
        return H*H_star
    
    def J_from_z(self, z):
        """Return the action J corresponding to a given amplitude z_r,
        corresponds to the inversion of z_r = zeta(J,0) : J = zeta^-1(z)
        """
        return self.func_J(z)
    
    def Omegas_from_z(self,z):
        """Return the synchrotron angular frequency omega_s for a given amplitude z_r"""
        return self.func_omegas(self.J_from_z(z))
    
    def G(self,m,p,pp,l,omega,delta_val = 0, n_pow = 8, Gmin = 1e-8, Gmax = 4.5e-2):
        """Compute the G integral"""
        g_func = lambda z : self.HH_coef(self.J_from_z(z), m, p, pp, l)*self.dphi0(z,delta_val)/(omega - m*self.Omegas_from_z(z))
        z0 = np.linspace(Gmin,Gmax,2**n_pow+1)
        y0 = g_func(z0)
        G_val = romb(y0,z0[1]-z0[0])
        #if( np.abs(y0[0]/G_val) > 0.001 or np.abs(y0[-1]/G_val) > 0.001 ):
            #print("Integration boundaries for G value might be wrong.")
            #plt.plot(z0,y0)
        return G_val
    
    def mpi_init(self):
        """Switch on mpi.

        Rank 0 returns immediately with self.mpi set to True. Every other
        rank enters an endless loop waiting for orders broadcast by rank 0:

        - "init" : receive the object used for the computation (VLASOV),
        - "B_matrix" : receive Bsize, omega, mmax and l_solve, compute one
          element of the matrix B and send it to rank 0 with a sum Reduce,
        - "stop" : exit the process.

        The order of the bcast/Reduce calls must match the one used by
        rank 0 in solveB, detB and mpi_exit.
        """
        
        self.mpi = True
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
                
        if(rank == 0):
            pass
        else:
            while(True):
                # Wait for the next order from rank 0
                order = comm.bcast(None,0)

                if(order == "init"):
                    # NOTE(review): VLASOV is only defined once an "init"
                    # order has been received, "B_matrix" relies on it.
                    VLASOV = comm.bcast(None,0)
                
                if(order == "B_matrix"):
                    Bsize = comm.bcast(None,0)
                    if Bsize**2 != comm.size - 1:
                        #raise ValueError("The number of processor must be Bsize**2 + 1, which is :",Bsize**2 + 1)
                        #sys.exit()
                        pass
                    omega = comm.bcast(None,0)
                    mmax = comm.bcast(None,0)
                    l_solve = comm.bcast(None,0)
                    
                    # Same index table as in detB: row 0 holds m, row 1
                    # holds -cav.m then +cav.m for each cavity
                    Ind_table = np.zeros((2,Bsize)) # m, p
                    for i in range(VLASOV.n_cavity):
                        cav = VLASOV.cavity_list[i]
                        Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
                        Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
                        Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
                        Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
                    
                    # Row and column index of every element of B, used to
                    # assign the element (i, j) handled by this rank
                    matrix_i = np.zeros((Bsize,Bsize))
                    matrix_j = np.zeros((Bsize,Bsize))
                    for i in range(Bsize):
                        for j in range(Bsize):
                            matrix_i[i,j] = i
                            matrix_j[i,j] = j
                    
                    # Rank r handles the (r - 1)-th element in row-major order
                    i = int(matrix_i.flatten()[rank-1])
                    j = int(matrix_j.flatten()[rank-1])
                    
                    # Only the element (i, j) is filled, the rest stays zero
                    # so that the sum over ranks rebuilds the full matrix
                    B = np.zeros((Bsize,Bsize), dtype=complex)
                    
                    omegap = VLASOV.ring.omega0*(Ind_table[1,j]*VLASOV.ring.h + l_solve)
                    # Total impedance of all cavities at omegap + omega
                    Z = np.zeros((1,),dtype=complex)
                    for k in range(VLASOV.n_cavity):
                        cav = VLASOV.cavity_list[k]
                        Z += cav.Z(omegap + omega)
                    if i == j:
                        B[i,j] += 1
                    B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*VLASOV.I0/VLASOV.ring.E0/VLASOV.ring.T0*c*Z/omegap*VLASOV.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
                    
                    # Send the contribution to rank 0
                    comm.Reduce([B, MPI.COMPLEX], None, op=MPI.SUM, root=0)

                if(order == "stop"):
                    sys.exit()               
                    
    def mpi_exit(self):
        """Switch off mpi: rank 0 broadcasts the "stop" order, other ranks
        only reset the mpi flag."""
        self.mpi = False
        comm = MPI.COMM_WORLD

        if comm.Get_rank() != 0:
            return
        comm.bcast("stop",0)
    
    def detB(self,omega,mmax,l_solve):
        """Return the determinant of the matrix B"""
        Bsize = 2*self.n_cavity*(2*mmax + 1)

        if(self.mpi):
            comm = MPI.COMM_WORLD
            comm.bcast("B_matrix",0)
            comm.bcast(Bsize,0)
            comm.bcast(omega,0)
            comm.bcast(mmax,0)
            comm.bcast(l_solve,0)
            B = np.zeros((Bsize,Bsize), dtype=complex)
            comm.Reduce([np.zeros((Bsize,Bsize), dtype=complex), MPI.COMPLEX], [B, MPI.COMPLEX],op=MPI.SUM, root=0)
        else:
            Ind_table = np.zeros((2,Bsize)) # m, p
            for i in range(self.n_cavity):
                cav = self.cavity_list[i]
                Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
                Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
                Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
                Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
            B = np.eye(Bsize, dtype=complex)
            for i in range(Bsize):
                for j in range(Bsize):
                    omegap = self.ring.omega0*(Ind_table[1,j]*self.ring.h + l_solve)
                    Z = np.zeros((1,),dtype=complex)
                    for k in range(self.n_cavity):
                        cav = self.cavity_list[k]
                        Z += cav.Z(omegap + omega)
                    B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*self.I0/self.ring.E0/self.ring.T0*c*Z/omegap*self.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
        return np.linalg.det(B)
    
    def solveB(self,omega0,mmax,l_solve, maxfev = 200):
        """Solve equation detB = 0
        
        Parameters
        ----------
        omega0 : initial guess with omega0[0] the real part of the solution and omega0[1] the imaginary part
        mmax : maximum absolute value of m, see Eq. 20 of [1]
        l_solve : instability coupled-bunch mode number
        maxfev : the maximum number of calls to the function
        
        Returns
        -------
        omega : solution
        infodict : dictionary of scipy.optimize.fsolve ouput
        ier : interger flag, set to 1 if a solution was found
        mesg : if no solution is found, mesg details the cause of failure
        """
        
        def func(omega):
            res = self.detB(omega[0] + 1j*omega[1],mmax,l_solve)
            return real(res), imag(res)

        if not self.mpi:
            omega, infodict, ier, mesg = fsolve(func,omega0, maxfev = maxfev, full_output = True)
        elif self.mpi and MPI.COMM_WORLD.rank == 0:
            comm = MPI.COMM_WORLD
            comm.bcast("init",0)
            comm.bcast(self, 0)
            omega, infodict, ier, mesg = fsolve(func,omega0, maxfev = maxfev, full_output = True)
        
        if ier != 1:
            print("The algorithm has not converged: " + str(ier))
            print(mesg)
            
        return omega, infodict, ier, mesg