Skip to content
Snippets Groups Projects
Select Git revision
  • main
  • release_2_2_1
  • release_2_2_0
  • release_2_1_3
  • release_2_1_2
  • release_2_1_1
  • release_2_1_0
  • release_2_0_2
  • release_2_0_1
  • release_2_0_0
  • release_1_0_0
11 results

pom.xml

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    beamloading.py 32.83 KiB
    # -*- coding: utf-8 -*-
    """
    Beam loading module
    Created on Fri Aug 23 13:32:03 2019
    
    @author: gamelina
    """
    
    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
    import matplotlib.path as mpltPath
    from mpl_toolkits.mplot3d import Axes3D
    import sys
    from mpi4py import MPI
    from scipy.optimize import root, fsolve
    from scipy.constants import c
    from scipy.integrate import solve_ivp, quad, romb
    from scipy.interpolate import interp1d, griddata
    from scipy import real, imag
    
    class BeamLoadingVlasov():
        """Class used to compute beam equilibrium profile and stability for a given
        storage ring and a list of RF cavities of any harmonic. The class assumes
        an uniform filling of the storage ring. Based on an extension of [1].
    
        [1] Venturini, M. (2018). Passive higher-harmonic rf cavities with general
        settings and multibunch instabilities in electron storage rings.
        Physical Review Accelerators and Beams, 21(11), 114404.
    
        Parameters
        ----------
        ring : Synchrotron object
        cavity_list : list of CavityResonator objects
        I0 : beam current in [A].
        auto_set_MC_theta : if True, allow class to change generator phase for
            CavityResonator objetcs with m = 1
        F : list of form factor amplitude
        PHI : list of form factor phase
        B1 : lower intergration boundary
        B2 : upper intergration boundary
        """
    
        def __init__(
                    self, ring, cavity_list, I0, auto_set_MC_theta=False, F=None,
                    PHI=None, B1=-0.2, B2=0.2):
            self.ring = ring
            self.cavity_list = cavity_list
            self.I0 = I0
            self.n_cavity = len(cavity_list)
            self.auto_set_MC_theta = auto_set_MC_theta
            if F is None:
                self.F = np.ones((self.n_cavity,))
            else:
                self.F = F
            if PHI is None:
                self.PHI = np.zeros((self.n_cavity,))
            else:
                self.PHI = PHI
            self.B1 = B1
            self.B2 = B2
            self.mpi = False
            self.__version__ = "1.0"
    
            # Define constants for scaled potential u(z)
            self.u0 = self.ring.U0 / (
                self.ring.ac * self.ring.sigma_delta**2
                * self.ring.E0 * self.ring.L)
            self.ug = np.zeros((self.n_cavity,))
            self.ub = np.zeros((self.n_cavity,))
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                self.ug[i] = cavity.Vg / (
                    self.ring.ac * self.ring.sigma_delta ** 2 *
                    self.ring.E0 * self.ring.L * self.ring.k1 *
                    cavity.m)
                self.ub[i] = 2 * self.I0 * cavity.Rs / (
                    self.ring.ac * self.ring.sigma_delta**2 *
                    self.ring.E0 * self.ring.L * self.ring.k1 *
                    cavity.m * (1 + cavity.beta))
            
        def energy_balance(self):
            """Return energy balance for the synchronous particle
            (z = 0 ,delta = 0)."""
            delta = self.ring.U0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                delta += cavity.Vb(self.I0) * self.F[i] * np.cos(cavity.psi - self.PHI[i])
                delta -= cavity.Vg * np.cos(cavity.theta_g)
            return delta
    
        def u(self, z):
            """Scaled potential u(z)"""
            pot = self.u0 * z
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                pot += - self.ug[i] * (
                    np.sin(self.ring.k1 * cavity.m * z + cavity.theta_g)
                    - np.sin(cavity.theta_g))
                pot += self.ub[i] * self.F[i] * np.cos(cavity.psi) * (
                    np.sin(self.ring.k1 * cavity.m * z
                           + cavity.psi - self.PHI[i])
                    - np.sin(cavity.psi - self.PHI[i]))
            return pot
        
        def du_dz(self, z):
            """Partial derivative of the scaled potential u(z) by z"""
            pot = self.u0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                pot += - self.ug[i] * self.ring.k1 * cavity.m * np.cos(self.ring.k1 * cavity.m * z + cavity.theta_g)
                pot += self.ub[i] * self.F[i] * self.ring.k1 * cavity.m * np.cos(cavity.psi) * np.cos(self.ring.k1 * cavity.m * z + cavity.psi - self.PHI[i])
            return pot
    
        def uexp(self, z):
            return np.exp(-1 * self.u(z))
    
        def integrate_func(self, f, g):
            """Return Integral[f*g]/Integral[f] between B1 and B2"""
            A = quad(lambda x: f(x) * g(x), self.B1, self.B2)
            B = quad(f, self.B1, self.B2)
            return A[0] / B[0]
    
        def to_solve(self, x):
            """System of non-linear equation to solve to find the form factors F
            and PHI at equilibrum.
            The system is composed of Eq. (B6) and (B7) of [1] for each cavity.
            If auto_set_MC_theta == True, the system also find the generator phase
            to impose energy balance.
            """
            # Update values of F, PHI and theta_g
            if self.auto_set_MC_theta:
                self.F = x[:-1:2]
                for i in range(self.n_cavity):
                    cavity = self.cavity_list[i]
                    if cavity.m == 1:
                        cavity.theta_g = x[-1]
            else:
                self.F = x[::2]
            self.PHI = x[1::2]
    
            # Compute system
            if self.auto_set_MC_theta:
                res = np.zeros((self.n_cavity * 2 + 1,))
                for i in range(self.n_cavity):
                    cavity = self.cavity_list[i]
                    res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
                        lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
                    res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
                        lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
                # Factor 1e-8 for better convergence
                res[self.n_cavity * 2] = self.energy_balance() * 1e-8
            else:
                res = np.zeros((self.n_cavity * 2,))
                for i in range(self.n_cavity):
                    cavity = self.cavity_list[i]
                    res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
                        lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
                    res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
                        lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
            return res
    
        def rho(self, z):
            """Return bunch equilibrium profile at postion z"""
            A = quad(lambda y: self.uexp(y), self.B1, self.B2)
            return self.uexp(z) / A[0]
    
        def plot_rho(self, z1=None, z2=None):
            """Plot the equilibrium bunch profile rho(z) on 1000 points.

            z1 and z2 default to the integration boundaries B1 and B2.
            """
            lower = self.B1 if z1 is None else z1
            upper = self.B2 if z2 is None else z2
            grid = np.linspace(lower, upper, 1000)
            plt.plot(grid, self.rho(grid))
            plt.xlabel("z [m]")
            plt.title("Equilibrium bunch profile")
            
        def voltage(self, z):
            """Return the RF system total voltage at position z"""
            Vtot = 0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                Vtot += cavity.VRF(z, self.I0, self.F[i], self.PHI[i])
            return Vtot
        
        def dV(self, z):
            """Return derivative of the RF system total voltage at position z"""
            Vtot = 0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                Vtot += cavity.dVRF(z, self.I0, self.F[i], self.PHI[i])
            return Vtot
        
        def ddV(self, z):
            """Return the second derivative of the RF system total voltage at position z"""
            Vtot = 0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                Vtot += cavity.ddVRF(z, self.I0, self.F[i], self.PHI[i])
            return Vtot
        
        def deltaVRF(self, z):
            """Return the generator voltage minus beam loading voltage of the total RF system at position z"""
            Vtot = 0
            for i in range(self.n_cavity):
                cavity = self.cavity_list[i]
                Vtot += cavity.deltaVRF(z, self.I0, self.F[i], self.PHI[i])
            return Vtot
        
        def plot_dV(self, z1=None, z2=None):
            """Plot the derivative of the RF system total voltage between z1
            and z2 (1000 points).

            z1 and z2 default to the integration boundaries B1 and B2.
            """
            if z1 is None:
                z1 = self.B1
            if z2 is None:
                z2 = self.B2
            z0 = np.linspace(z1, z2, 1000)
            plt.plot(z0, self.dV(z0))
            plt.xlabel("z [m]")
            # The plotted quantity is dV/dz, not V: label it accordingly.
            plt.ylabel("Derivative of total RF voltage (V/m)")
            
        def plot_voltage(self, z1=None, z2=None):
            """Plot the total RF voltage on 1000 points between z1 and z2.

            z1 and z2 default to the integration boundaries B1 and B2.
            """
            lower = self.B1 if z1 is None else z1
            upper = self.B2 if z2 is None else z2
            grid = np.linspace(lower, upper, 1000)
            plt.plot(grid, self.voltage(grid))
            plt.xlabel("z [m]")
            plt.ylabel("Total RF voltage (V)")
    
        def std_rho(self, z1=None, z2=None):
            """Return the rms bunch equilibrium size in [m]"""
            if z1 is None:
                z1 = self.B1
            if z2 is None:
                z2 = self.B2
            z0 = np.linspace(z1, z2, 1000)
            values = self.rho(z0)
            average = np.average(z0, weights=values)
            variance = np.average((z0 - average)**2, weights=values)
            return np.sqrt(variance)
    
        def beam_equilibrium(self, x0=None, tol=1e-4, method='hybr', options=None, plot = False):
            """Solve system of non-linear equation to find the form factors F
            and PHI at equilibrum. Can be used to compute the equilibrium bunch
            profile.
            
            Parameters
            ----------
            x0 : initial guess
            tol : tolerance for termination of the algorithm
            method : method used by scipy.optimize.root to solve the system
            options : options given to scipy.optimize.root
            plot : if True, plot the equilibrium bunch profile
            
            Returns
            -------
            sol : OptimizeResult object representing the solution
            """
            if x0 is None:
                x0 = [1, 0] * self.n_cavity
                if self.auto_set_MC_theta:
                    x0 = x0 + [self.cavity_list[0].theta_g]
    
            print("The initial energy balance is " +
                  str(self.energy_balance()) + " eV")
    
            sol = root(self.to_solve, x0, tol=tol, method=method, options=options)
    
            # Update values of F, PHI and theta_g
            if self.auto_set_MC_theta:
                self.F = sol.x[:-1:2]
                for i in range(self.n_cavity):
                    cavity = self.cavity_list[i]
                    if cavity.m == 1:
                        cavity.theta_g = sol.x[-1]
            else:
                self.F = sol.x[::2]
            self.PHI = sol.x[1::2]
    
            print("The final energy balance is " +
                  str(self.energy_balance()) + " eV")
            print("The algorithm has converged: " + str(sol.success))
            
            if plot:
                self.plot_rho(self.B1 / 4, self.B2 / 4)
    
            return sol
    
        def H0(self, z):
            """Unperturbed Hamiltonian for delta = 0"""
            return self.u(z)*self.ring.ac*c*self.ring.sigma_delta**2
    
        def H0_val(self, z, value):
            return self.H0(value) - self.H0(z)
        
        def Jfunc(self, z, zri_val):
            """Convenience function to compute the J integral"""
            return np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
                           - 2*self.u(z)*self.ring.sigma_delta**2)
    
        def Tsfunc(self, z, zri_val):
            """Convenience function to compute the Ts integral"""
            return 1/np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
                           - 2*self.u(z)*self.ring.sigma_delta**2)
            
        def dz_dt(self, delta):
            """dz/dt part of the equation of motion for synchrotron oscillation"""
            return self.ring.ac*c*delta
        
        def ddelta_dt(self, z):
            """ddelta/dt part of the equation of motion for synchrotron oscillation"""
            VRF = 0
            for i in range(self.n_cavity):
                cav = self.cavity_list[i]
                VRF += cav.VRF(z, self.I0, self.F[i],self.PHI[i])
            return (VRF - self.ring.U0)/(self.ring.E0*self.ring.T0)
                
        def EOMsystem(self, t, y):
            """System of equations of motion for synchrotron oscillation
            
            Parameters
            ----------
            t : time
            y[0] : z
            y[1] : delta
            """
            
            res = np.zeros(2,)
            res[0] = self.dz_dt(y[1])
            res[1] = self.ddelta_dt(y[0])
            return res
        
        def cannonical_transform(self, zmax = 4.5e-2, tmax = 2e-3, nz = 1000,
                                 ntime = 1000, plot = False, epsilon = 1e-10,
                                 rtol = 1e-9, atol = 1e-13, epsrel = 1e-9,
                                 epsabs = 1e-11, zri0_coef = 1e-2, n_pow = 16,
                                 eps = 1e-9):
            """Compute cannonical transform zeta, see [1] appendix C.

            Parameters
            ----------
            zmax : maximum amplitude of the z-grid for the cannonical transform
                computation
            nz : number of point in the z-grid
            tmax : maximum amplitude of the time-grid
            ntime : number of point in the time-grid (may differ from nz)
            plot : if True, plot the surface map of the canonical transform
                and of the numerical Jacobian
            epsilon : convergence criterium for the maximum value of uexp(zmax)
                and uexp(zmin)
            rtol, atol : relative and absolute tolerances for the equation of
                motion solving
            epsabs, epsrel : relative and absolute tolerances for the integration
                of J and Ts
            zri0_coef : lower boundary of the z-grid is given by dz * zri0_coef, to
                be lowered if lower values of z/J are needed in the grid
            n_pow : power of two used in the romb integration of J/Ts if the error
                on J/Ts is bigger than 0.1 % or if the value is nan
            eps : if the romb integration is used, the integration interval is
                limited to [zli[i] + eps,zri[i] - eps]

            Results are stored on the instance: J, phi, z, delta, Omegas, Ts,
            the flattened scattered points J_p, phi_p, z_p used by zeta(), the
            interpolators func_J (z -> J) and func_omegas (J -> Omegas), and
            path, the polygon bounding the (J, phi) map.

            Raises
            ------
            Exception : if no zmin is found, or if uexp(zmax) or uexp(zmin) is
                larger than epsilon.
            """

            # step (i): zmin is the turning point on the other side of the
            # well with the same delta = 0 energy as zmax.
            sol = root(lambda z: self.H0_val(z, zmax), -zmax)
            if not sol.success:
                raise Exception("Error, no solution for zmin was found : " + sol.message)
            zmin = sol.x[0]

            if self.uexp(zmax) > epsilon or self.uexp(zmin) > epsilon:
                raise Exception("Error, uexp(zmax) > epsilon or uexp(zmin) > epsilon, try to increase zmax")

            # step (ii): right turning points zri and matching left turning
            # points zli on the same energy level.
            dz = (zmax - zmin)/nz
            zri = np.arange(nz)*dz
            zri[0] = dz*zri0_coef
            zli = np.zeros((nz,))
            for i in range(nz):
                zl, _, ier, mesg = fsolve(lambda z: self.H0_val(z, zri[i]),
                                          -zri[i], full_output=True)
                zli[i] = zl[0]
                if ier != 1:
                    print("Error at i = " + str(i) + " : " + mesg)

            # step (iii): action J and synchrotron period Ts of each orbit.
            J = np.zeros((nz,))
            Ts = np.zeros((nz,))
            for i in range(nz):
                J[i] = self._orbit_integral(
                    lambda z: 1/np.pi*self.Jfunc(z, zri[i]), zli[i], zri[i],
                    "J", i, epsabs, epsrel, n_pow, eps)
                Ts[i] = self._orbit_integral(
                    lambda z: 2*self.Tsfunc(z, zri[i])/(self.ring.ac*c),
                    zli[i], zri[i], "Ts", i, epsabs, epsrel, n_pow, eps)

            Omegas = 2*np.pi/Ts
            zri[0] = 0 # approximation to dz*zri0_coef

            # step (iv): track each orbit from (zri, 0) and sample it on a
            # common time grid.
            z_sol = np.zeros((nz, ntime))
            delta_sol = np.zeros((nz, ntime))
            phi_sol = np.zeros((nz, ntime))
            tspan = (0, tmax)
            time_base = np.linspace(tspan[0], tspan[1], ntime)
            for i in range(nz):
                orbit = solve_ivp(self.EOMsystem, tspan, (zri[i], 0),
                                  rtol=rtol, atol=atol)
                # cubic decreases K std but increases a bit K mean
                fz = interp1d(orbit.t, orbit.y[0], kind='cubic')
                fdelta = interp1d(orbit.t, orbit.y[1], kind='cubic')
                z_sol[i, :] = fz(time_base)
                delta_sol[i, :] = fdelta(time_base)
                phi_sol[i, :] = Omegas[i]*time_base

            # Plot the surface
            if plot:
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                surf = ax.plot_surface(J, phi_sol.T, z_sol.T, rcount=75,
                                       ccount=75, cmap=plt.cm.viridis,
                                       linewidth=0, antialiased=True)
                fig.colorbar(surf, shrink=0.5, aspect=5)
                ax.set_xlabel('J [m]')
                ax.set_ylabel(r'$\phi$ [rad]')
                ax.set_zlabel('z [m]')

            # Check the results by computing the Jacobian matrix with centred
            # finite differences on the interior points (borders copied below).
            dz_dp = np.zeros((nz, ntime))
            dz_dJ = np.zeros((nz, ntime))
            dd_dp = np.zeros((nz, ntime))
            dd_dJ = np.zeros((nz, ntime))

            dJ = (J[2:] - J[:-2])[:, np.newaxis]
            for i in np.flatnonzero(dJ == 0) + 1:
                print(i)
            dphi = phi_sol[1:-1, 2:] - phi_sol[1:-1, :-2]
            dz_dp[1:-1, 1:-1] = (z_sol[1:-1, 2:] - z_sol[1:-1, :-2])/dphi
            dd_dp[1:-1, 1:-1] = (delta_sol[1:-1, 2:] - delta_sol[1:-1, :-2])/dphi
            dz_dJ[1:-1, 1:-1] = (z_sol[2:, 1:-1] - z_sol[:-2, 1:-1])/dJ
            dd_dJ[1:-1, 1:-1] = (delta_sol[2:, 1:-1] - delta_sol[:-2, 1:-1])/dJ

            K = np.abs(dz_dp*dd_dJ - dd_dp*dz_dJ)
            K[0, :] = K[1, :]
            K[-1, :] = K[-2, :]
            K[:, 0] = K[:, 1]
            K[:, -1] = K[:, -2]

            print('Numerical Jacobian K: mean = ' + str(np.mean(K)) + ', value should be around 1.')
            print('Numerical Jacobian K: std = ' + str(np.std(K)) + ', value should be around 0.')
            print('Numerical Jacobian K: max = ' + str(np.max(K)) + ', value should be around 1.')
            print('If the values are far from their nominal values, decrease atol, rtol, epsrel and epsabs.')

            # Plot the numerical Jacobian
            if plot:
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                surf = ax.plot_surface(J, phi_sol.T, K.T, rcount=75, ccount=75,
                                       cmap=plt.cm.viridis,
                                       linewidth=0, antialiased=True)
                fig.colorbar(surf, shrink=0.5, aspect=5)
                ax.set_xlabel('J [m]')
                ax.set_ylabel(r'$\phi$ [rad]')
                ax.set_zlabel('K')

            self.J = J
            self.phi = phi_sol.T
            self.z = z_sol.T
            self.delta = delta_sol.T
            self.Omegas = Omegas
            self.Ts = Ts

            # Scattered (J, phi, z) points for zeta(): J is repeated along the
            # time axis so that it has the same (ntime, nz) shape as
            # phi_sol.T and z_sol.T (also valid when nz != ntime).
            JJ = np.tile(J, (ntime, 1))
            J_p, phi_p, z_p = np.vstack(
                [np.ravel(arr) for arr in (JJ, phi_sol.T, z_sol.T)])

            self.J_p = J_p
            self.phi_p = phi_p
            self.z_p = z_p

            # Provide interpolation functions for methods
            self.func_J = interp1d(z_sol[:, 0], J, bounds_error=True)
            self.func_omegas = interp1d(J, Omegas, bounds_error=True)

            # Polygon bounding the map: upper edge (phi max) for increasing J,
            # then lower edge (phi min) back down to and including J[0].
            p1 = [[J[i], phi_sol[i, :].max()] for i in range(J.size)]
            p2 = [[J[i], phi_sol[i, :].min()] for i in range(J.size - 1, -1, -1)]
            self.path = mpltPath.Path(p1 + p2)

        def _orbit_integral(self, integrand, zl, zr, label, i, epsabs, epsrel,
                            n_pow, eps):
            """Integrate integrand over [zl, zr] with quad; fall back on romb
            over [zl + eps, zr - eps] (2**n_pow + 1 points) when the relative
            error exceeds 0.1 % or the result is nan. label and i are only used
            in the printed diagnostics."""
            value, err = quad(integrand, zl, zr, epsabs=epsabs, epsrel=epsrel,
                              limit=int(1e4))

            def romb_value():
                grid = np.linspace(zl + eps, zr - eps, 2**n_pow + 1)
                return romb(integrand(grid), grid[1] - grid[0])

            # numpy division: inf/nan with a warning instead of raising on 0.
            rel_err = np.float64(err)/value
            if rel_err > 0.001:
                print("Error is bigger than 0.1% for i = " + str(i)
                      + ", " + label + " = " + str(value) +
                      ", relative error on " + label + " = " + str(rel_err))
                value = romb_value()
                print("Using romb to compute the integral instead of quad : " + label + " = " + str(value))

            if np.isnan(value):
                print(label + " is nan for i = " + str(i))
                value = romb_value()
                print("Using romb to compute the integral instead of quad : " + label + " = " + str(value))
            return value
            
        def dphi0(self, z, delta):
            """Partial derivative of the distribution function phi by z"""
            dphi0 = -np.exp(-delta**2/(2*self.ring.sigma_delta**2)) / (np.sqrt(2*np.pi)*self.ring.sigma_delta) * self.rho(z)*self.du_dz(z)
            return dphi0
                
        def zeta(self, J, phi):
            """Compute zeta cannonical transformation z = zeta(J,phi) using 2d interpolation
            interp2d and bisplrep does not seem to work for this case
            griddata is working fine if method='linear'
            griddata return nan if the point is outside the grid
            """
            z = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J, phi), method='linear')
            if np.isnan(np.sum(z)) == True:
                inside = self.path.contains_points(np.array([J, phi]).T)
                if (np.sum(inside != True) >= 1):
                    print("Some points are outside the grid of the cannonical transformation zeta !")
                    print(str(np.sum(inside != True)) + " points are outside the grid out of " + str(J.size) + " points.")
                    plot_points_outside = False
                    if(plot_points_outside):
                        fig = plt.figure()
                        ax = fig.add_subplot(111)
                        patch = patches.PathPatch(self.path, facecolor='orange', lw=2)
                        ax.add_patch(patch)
                        ax.set_xlim(self.J_p.min()*0.5,self.J_p.max()*1.5)
                        ax.set_ylim(self.phi_p.min()*0.5,self.phi_p.max()*1.5)
                        ax.scatter(J[np.nonzero(np.invert(inside))[0]],phi[np.nonzero(np.invert(inside))[0]])
                        plt.show()
                        vals = np.nonzero(np.invert(inside))[0]
                        print("J min = " + str(J[vals].min()))
                        print("J max = " + str(J[vals].max()))
                        print("Phi min = " + str(phi[vals].min()))
                        print("Phi max = " + str(phi[vals].max()))
                else:
                    print("Some values interpolated from the zeta cannonical transformation are nan !")
                    print("This may be due to a scipy version < 1.3")
                print("The nan values are remplaced by nearest values in the grid")
                arr = np.isnan(z)
                z[arr] = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J[arr], phi[arr]), method='nearest')
            if np.isnan(np.sum(z)) == True:
                print("Correction by nearest values in the grid failed, program is now stopping")
                raise ValueError("zeta cannonical transformation are nan")
            return z
                
        def H_coef(self, J_val, m, p, l, n_pow = 14):
            """Compute H_{m,p} (J_val) using zeta cannonical transformation"""
            """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
            omegap = self.ring.omega0*(p*self.ring.h + l)
            phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
            res = np.zeros(J_val.shape, dtype=complex)
            phi0_array = np.tile(phi0, J_val.size)
            J_array = np.array([])
            for i in range(J_val.size):
                J_val_array = np.tile(J_val[i], phi0.size)
                J_array = np.concatenate((J_array, J_val_array))
            zeta_values = self.zeta(J_array, phi0_array)
            y0 = np.exp(1j*m*phi0_array + 1j*omegap*zeta_values/c)
            for i in range(J_val.size):
                res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
            return res
        
        def H_coef_star(self, J_val, m, p, l, n_pow = 14):
            """Compute H_{m,p}^* (J_val) using zeta cannonical transformation"""
            """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
            omegap = self.ring.omega0*(p*self.ring.h + l)
            phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
            res = np.zeros(J_val.shape, dtype=complex)
            phi0_array = np.tile(phi0, J_val.size)
            J_array = np.array([])
            for i in range(J_val.size):
                J_val_array = np.tile(J_val[i], phi0.size)
                J_array = np.concatenate((J_array, J_val_array))
            zeta_values = self.zeta(J_array, phi0_array)
            y0 = np.exp( - 1j*m*phi0_array - 1j*omegap*zeta_values/c)
            for i in range(J_val.size):
                res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
            return res
        
        def HH_coef(self, J_val, m, p, pp, l, n_pow = 10):
            """Compute H_{m,p} (J_val) * H_{m,pp}^* (J_val) using zeta cannonical transformation"""
            """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
            omegapH = self.ring.omega0*(p*self.ring.h + l)
            omegapH_star = self.ring.omega0*(pp*self.ring.h + l)
            phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
            H = np.zeros(J_val.shape, dtype=complex)
            H_star = np.zeros(J_val.shape, dtype=complex)
            phi0_array = np.tile(phi0, J_val.size)
            J_array = np.array([])
            for i in range(J_val.size):
                J_val_array = np.tile(J_val[i], phi0.size)
                J_array = np.concatenate((J_array, J_val_array))
            zeta_values = self.zeta(J_array, phi0_array)
            y0 = np.exp( 1j*m*phi0_array + 1j*omegapH*zeta_values/c)
            y1 = np.exp( - 1j*m*phi0_array - 1j*omegapH_star*zeta_values/c)
            for i in range(J_val.size):
                H[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
                H_star[i] = romb(y1[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
            return H*H_star
        
        def J_from_z(self, z):
            """Return the action J corresponding to a given amplitude z_r,
            corresponds to the inversion of z_r = zeta(J,0) : J = zeta^-1(z)
            """
            return self.func_J(z)
        
        def Omegas_from_z(self,z):
            """Return the synchrotron angular frequency omega_s for a given amplitude z_r"""
            return self.func_omegas(self.J_from_z(z))
        
        def G(self,m,p,pp,l,omega,delta_val = 0, n_pow = 8, Gmin = 1e-8, Gmax = 4.5e-2):
            """Compute the G integral"""
            g_func = lambda z : self.HH_coef(self.J_from_z(z), m, p, pp, l)*self.dphi0(z,delta_val)/(omega - m*self.Omegas_from_z(z))
            z0 = np.linspace(Gmin,Gmax,2**n_pow+1)
            y0 = g_func(z0)
            G_val = romb(y0,z0[1]-z0[0])
            #if( np.abs(y0[0]/G_val) > 0.001 or np.abs(y0[-1]/G_val) > 0.001 ):
                #print("Integration boundaries for G value might be wrong.")
                #plt.plot(z0,y0)
            return G_val
        
        def mpi_init(self):
            """Switch on mpi.

            Sets self.mpi = True. On rank 0 the method returns immediately.
            Every other rank enters an endless worker loop: it waits for an
            order broadcast from rank 0 and reacts to it:

            - "init": receive an object (stored as VLASOV) whose attributes
              are used by the "B_matrix" order below — presumably the
              BeamLoadingVlasov instance of rank 0; verify against detB.
            - "B_matrix": receive Bsize, omega, mmax and l_solve, compute the
              single element (i, j) of the B matrix assigned to this rank
              (flattened index rank - 1) and send it to rank 0 with an
              MPI SUM reduction of a Bsize x Bsize complex matrix.
            - "stop": terminate the process with sys.exit().

            NOTE(review): a "B_matrix" order received before any "init" order
            raises NameError because VLASOV is undefined. A rank with
            rank - 1 >= Bsize**2 raises IndexError; the size check below is
            disabled.
            """
            
            self.mpi = True
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
                    
            if(rank == 0):
                pass
            else:
                # Worker loop: only returns through sys.exit() on "stop".
                while(True):
                    order = comm.bcast(None,0)
    
                    if(order == "init"):
                        VLASOV = comm.bcast(None,0)
                    
                    if(order == "B_matrix"):
                        Bsize = comm.bcast(None,0)
                        # One worker per matrix element is expected; the check
                        # is currently disabled (no error is raised).
                        if Bsize**2 != comm.size - 1:
                            #raise ValueError("The number of processor must be Bsize**2 + 1, which is :",Bsize**2 + 1)
                            #sys.exit()
                            pass
                        omega = comm.bcast(None,0)
                        mmax = comm.bcast(None,0)
                        l_solve = comm.bcast(None,0)
                        
                        # Row 0: azimuthal mode m in [-mmax, mmax]; row 1: p,
                        # equal to -cav.m then +cav.m for each cavity, each
                        # block being 2*mmax + 1 columns wide.
                        Ind_table = np.zeros((2,Bsize)) # m, p
                        for i in range(VLASOV.n_cavity):
                            cav = VLASOV.cavity_list[i]
                            Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
                            Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
                            Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
                            Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
                        
                        # Map this rank to its (i, j) element: rank 1 -> (0, 0),
                        # rank 2 -> (0, 1), ... (row-major order).
                        matrix_i = np.zeros((Bsize,Bsize))
                        matrix_j = np.zeros((Bsize,Bsize))
                        for i in range(Bsize):
                            for j in range(Bsize):
                                matrix_i[i,j] = i
                                matrix_j[i,j] = j
                        
                        i = int(matrix_i.flatten()[rank-1])
                        j = int(matrix_j.flatten()[rank-1])
                        
                        # Only B[i, j] is filled here; the SUM reduction on
                        # rank 0 assembles the full matrix.
                        B = np.zeros((Bsize,Bsize), dtype=complex)
                        
                        omegap = VLASOV.ring.omega0*(Ind_table[1,j]*VLASOV.ring.h + l_solve)
                        # Total impedance of all cavities at omegap + omega.
                        Z = np.zeros((1,),dtype=complex)
                        for k in range(VLASOV.n_cavity):
                            cav = VLASOV.cavity_list[k]
                            Z += cav.Z(omegap + omega)
                        if i == j:
                            B[i,j] += 1
                        # NOTE(review): Z has shape (1,), so this stores a
                        # size-1 array into a scalar element, which newer NumPy
                        # deprecates.
                        B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*VLASOV.I0/VLASOV.ring.E0/VLASOV.ring.T0*c*Z/omegap*VLASOV.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
                        
                        comm.Reduce([B, MPI.COMPLEX], None, op=MPI.SUM, root=0)
    
                    if(order == "stop"):
                        sys.exit()               
                        
        def mpi_exit(self):
            """Switch off mpi"""
            
            self.mpi = False
            comm = MPI.COMM_WORLD
            rank = comm.Get_rank()
            
            if(rank == 0):
                comm.bcast("stop",0)
        
        def detB(self,omega,mmax,l_solve):
            """Return the determinant of the matrix B"""
            Bsize = 2*self.n_cavity*(2*mmax + 1)
    
            if(self.mpi):
                comm = MPI.COMM_WORLD
                comm.bcast("B_matrix",0)
                comm.bcast(Bsize,0)
                comm.bcast(omega,0)
                comm.bcast(mmax,0)
                comm.bcast(l_solve,0)
                B = np.zeros((Bsize,Bsize), dtype=complex)
                comm.Reduce([np.zeros((Bsize,Bsize), dtype=complex), MPI.COMPLEX], [B, MPI.COMPLEX],op=MPI.SUM, root=0)
            else:
                Ind_table = np.zeros((2,Bsize)) # m, p
                for i in range(self.n_cavity):
                    cav = self.cavity_list[i]
                    Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
                    Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
                    Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
                    Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
                B = np.eye(Bsize, dtype=complex)
                for i in range(Bsize):
                    for j in range(Bsize):
                        omegap = self.ring.omega0*(Ind_table[1,j]*self.ring.h + l_solve)
                        Z = np.zeros((1,),dtype=complex)
                        for k in range(self.n_cavity):
                            cav = self.cavity_list[k]
                            Z += cav.Z(omegap + omega)
                        B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*self.I0/self.ring.E0/self.ring.T0*c*Z/omegap*self.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
            return np.linalg.det(B)
        
        def solveB(self,omega0,mmax,l_solve, maxfev = 200):
            """Solve equation detB = 0

            The complex unknown omega is split into its real and imaginary
            parts so that ``scipy.optimize.fsolve`` can work on a real
            2-vector.

            Parameters
            ----------
            omega0 : initial guess with omega0[0] the real part of the solution and omega0[1] the imaginary part
            mmax : maximum absolute value of m, see Eq. 20 of [1]
            l_solve : instability coupled-bunch mode number
            maxfev : the maximum number of calls to the function

            Returns
            -------
            omega : solution
            infodict : dictionary of scipy.optimize.fsolve output
            ier : integer flag, set to 1 if a solution was found
            mesg : if no solution is found, mesg details the cause of failure

            Raises
            ------
            RuntimeError
                If called in mpi mode on a rank other than 0 (worker ranks
                must stay in the loop started by mpi_init).
            """

            def func(omega):
                res = self.detB(omega[0] + 1j*omega[1], mmax, l_solve)
                return np.real(res), np.imag(res)

            if self.mpi:
                comm = MPI.COMM_WORLD
                if comm.Get_rank() != 0:
                    raise RuntimeError(
                        "solveB must only be called on MPI rank 0 in mpi mode.")
                # Send the current state of self to the workers so they build
                # B from the same parameters as rank 0.
                comm.bcast("init", 0)
                comm.bcast(self, 0)

            omega, infodict, ier, mesg = fsolve(func, omega0, maxfev=maxfev,
                                                full_output=True)

            if ier != 1:
                print("The algorithm has not converged: " + str(ier))
                print(mesg)

            return omega, infodict, ier, mesg