From c5a79d2495f08ecdc98df5ef4c3c7450b72ef767 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 30 Mar 2022 15:25:27 +0200
Subject: [PATCH] Refactor beamloading module

Remove untested code
Rename BeamLoadingVlasov to BeamLoadingEquilibrium
---
 collective_effects/instabilities.py |  83 +----
 vlasov/__init__.py                  |   2 +-
 vlasov/beamloading.py               | 521 +---------------------------
 3 files changed, 6 insertions(+), 600 deletions(-)

diff --git a/collective_effects/instabilities.py b/collective_effects/instabilities.py
index 9ac6622..83089cc 100644
--- a/collective_effects/instabilities.py
+++ b/collective_effects/instabilities.py
@@ -6,7 +6,6 @@ General calculations about instabilities
 """
 
 import numpy as np
-import matplotlib.pyplot as plt
 from scipy.constants import c, m_e, e, pi, epsilon_0
 import math
 
@@ -204,57 +203,6 @@ def lcbi_growth_rate(ring, I, Vrf, M, fr=None, Rs=None, QL=None, Z=None):
     mu = np.argmax(growth_rates)
     
     return growth_rate, mu, growth_rates
-
-def plot_critical_mass(ring, bunch_charge, bunch_spacing, n_points=1e4):
-    """
-    Plot ion critical mass, using Eq. (7.70) p147 of [1]
-
-    Parameters
-    ----------
-    ring : Synchrotron object
-    bunch_charge : float
-        Bunch charge in [C].
-    bunch_spacing : float
-        Time in between two adjacent bunches in [s].
-    n_points : float or int, optional
-        Number of point used in the plot. The default is 1e4.
-
-    Returns
-    -------
-    fig : figure
-        
-    References
-    ----------
-    [1] : Gamelin, A. (2018). Collective effects in a transient microbunching 
-    regime and ion cloud mitigation in ThomX (Doctoral dissertation, 
-    Université Paris-Saclay).
-
-    """
-    
-    n_points = int(n_points)
-    s = np.linspace(0, ring.L, n_points)
-    sigma = ring.sigma(s)
-    rp = 1.534698250004804e-18 # Proton classical radius, m
-    N = np.abs(bunch_charge/e)
-    
-    Ay = N*rp*bunch_spacing*c/(2*sigma[2,:]*(sigma[2,:] + sigma[0,:]))
-    Ax = N*rp*bunch_spacing*c/(2*sigma[0,:]*(sigma[2,:] + sigma[0,:]))
-    
-    fig = plt.figure()
-    ax = plt.gca()
-    ax.plot(s, Ax, label=r"$A_x^c$")
-    ax.plot(s, Ay, label=r"$A_y^c$")
-    ax.set_yscale("log")
-    ax.plot(s, np.ones_like(s)*2, label=r"$H_2^+$")
-    ax.plot(s, np.ones_like(s)*16, label=r"$H_2O^+$")
-    ax.plot(s, np.ones_like(s)*18, label=r"$CH_4^+$")
-    ax.plot(s, np.ones_like(s)*28, label=r"$CO^+$")
-    ax.plot(s, np.ones_like(s)*44, label=r"$CO_2^+$")
-    ax.legend()
-    ax.set_ylabel("Critical mass")
-    ax.set_xlabel("Longitudinal position [m]")
-    
-    return fig
     
 def rwmbi_growth_rate(ring, current, beff, rho_material, plane='x'):
     """
@@ -327,33 +275,4 @@ def rwmbi_threshold(ring, beff, rho_material, plane='x'):
     Ith = (4*np.pi*E0*beff**3) / (c*beta0*tau_rad) * (((1-frac_tune)*omega0) / (2*c*Z0*rho_material))**0.5
     
     return Ith
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
\ No newline at end of file
+       
\ No newline at end of file
diff --git a/vlasov/__init__.py b/vlasov/__init__.py
index d79ec51..45a0308 100644
--- a/vlasov/__init__.py
+++ b/vlasov/__init__.py
@@ -5,4 +5,4 @@ Created on Tue Jan 14 18:11:33 2020
 @author: gamelina
 """
 
-from mbtrack2.vlasov.beamloading import BeamLoadingVlasov
\ No newline at end of file
+from mbtrack2.vlasov.beamloading import BeamLoadingEquilibrium
\ No newline at end of file
diff --git a/vlasov/beamloading.py b/vlasov/beamloading.py
index a91b5b2..87aea1a 100644
--- a/vlasov/beamloading.py
+++ b/vlasov/beamloading.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 """
-Beam loading module
+Beam loading equilibrium module
 Created on Fri Aug 23 13:32:03 2019
 
 @author: gamelina
@@ -8,18 +8,11 @@ Created on Fri Aug 23 13:32:03 2019
 
 import numpy as np
 import matplotlib.pyplot as plt
-import matplotlib.patches as patches
-import matplotlib.path as mpltPath
-from mpl_toolkits.mplot3d import Axes3D
-import sys
-from mpi4py import MPI
-from scipy.optimize import root, fsolve
+from scipy.optimize import root
 from scipy.constants import c
-from scipy.integrate import solve_ivp, quad, romb
-from scipy.interpolate import interp1d, griddata
-from scipy import real, imag
+from scipy.integrate import quad
 
-class BeamLoadingVlasov():
+class BeamLoadingEquilibrium():
     """Class used to compute beam equilibrium profile and stability for a given
     storage ring and a list of RF cavities of any harmonic. The class assumes
     an uniform filling of the storage ring. Based on an extension of [1].
@@ -295,509 +288,3 @@ class BeamLoadingVlasov():
             self.plot_rho(self.B1 / 4, self.B2 / 4)
 
         return sol
-
-    def H0(self, z):
-        """Unperturbed Hamiltonian for delta = 0"""
-        return self.u(z)*self.ring.ac*c*self.ring.sigma_delta**2
-
-    def H0_val(self, z, value):
-        return self.H0(value) - self.H0(z)
-    
-    def Jfunc(self, z, zri_val):
-        """Convenience function to compute the J integral"""
-        return np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
-                       - 2*self.u(z)*self.ring.sigma_delta**2)
-
-    def Tsfunc(self, z, zri_val):
-        """Convenience function to compute the Ts integral"""
-        return 1/np.sqrt(2*self.H0(zri_val)/(self.ring.ac*c)
-                       - 2*self.u(z)*self.ring.sigma_delta**2)
-        
-    def dz_dt(self, delta):
-        """dz/dt part of the equation of motion for synchrotron oscillation"""
-        return self.ring.ac*c*delta
-    
-    def ddelta_dt(self, z):
-        """ddelta/dt part of the equation of motion for synchrotron oscillation"""
-        VRF = 0
-        for i in range(self.n_cavity):
-            cav = self.cavity_list[i]
-            VRF += cav.VRF(z, self.I0, self.F[i],self.PHI[i])
-        return (VRF - self.ring.U0)/(self.ring.E0*self.ring.T0)
-            
-    def EOMsystem(self, t, y):
-        """System of equations of motion for synchrotron oscillation
-        
-        Parameters
-        ----------
-        t : time
-        y[0] : z
-        y[1] : delta
-        """
-        
-        res = np.zeros(2,)
-        res[0] = self.dz_dt(y[1])
-        res[1] = self.ddelta_dt(y[0])
-        return res
-    
-    def cannonical_transform(self, zmax = 4.5e-2, tmax = 2e-3, nz = 1000, 
-                             ntime = 1000, plot = False, epsilon = 1e-10,
-                             rtol = 1e-9, atol = 1e-13, epsrel = 1e-9,
-                             epsabs = 1e-11, zri0_coef = 1e-2, n_pow = 16,
-                             eps = 1e-9):
-        """Compute cannonical transform zeta, see [1] appendix C.
-
-        Parameters
-        ----------
-        zmax : maximum amplitude of the z-grid for the cannonical transform
-            computation
-        nz : number of point in the z-grid
-        tmax : maximum amplitude of the time-grid
-        ntime : number of point in the time-grid
-        plot : if True, plot the surface map of the connonical transform
-        epsilon : convergence criterium for the maximum value of uexp(zmax) 
-            and uexp(zmin)
-        rtol, atol : relative and absolute tolerances for the equation of 
-            motion solving
-        epsabs, epsrel : relative and absolute tolerances for the integration
-            of J and Ts
-        zri0_coef : lower boundary of the z-grid is given by dz * zri0_coef, to
-            be lowered if lower values of z/J are needed in the grid
-        n_pow : power of two used in the romb integration of J/Ts if the error
-            on J/Ts is bigger than 0.1 % or if the value is nan
-        eps : if the romb integration is used, the integration interval is
-            limited to [zli[i] + eps,zri[i] - eps]
-        """
-        
-        # step (i)
-        
-        sol = root(lambda z : self.H0_val(z,zmax), -zmax)
-        zmin = sol.x
-        
-        if sol.success == False:
-            raise Exception("Error, no solution for zmin was found : " + sol.message)
-        
-        if self.uexp(zmax) > epsilon or self.uexp(zmin) > epsilon:
-            raise Exception("Error, uexp(zmax) < epsilon or uexp(zmin) < epsilon, try to increase zmax")
-            
-        # step (ii)
-        
-        dz = (zmax-zmin)/nz
-        zri = np.arange(nz)*dz
-        zri[0] = dz*zri0_coef
-        zli = np.zeros((nz,))
-        
-        for i in range(nz):
-            func = lambda z : self.H0_val(z,zri[i])
-            zli[i], infodict, ier, mesg = fsolve(func,-zri[i], full_output = True)
-            if ier != 1:
-                print("Error at i = " + str(i) + " : " + mesg)
-                
-        # step (iii)
-        
-        J = np.zeros((nz,))
-        J_err = np.zeros((nz,))
-        Ts = np.zeros((nz,))
-        Ts_err = np.zeros((nz,))
-        
-        for i in range(0,nz):
-            J[i], J_err[i] = quad(lambda z : 1/np.pi*self.Jfunc(z,zri[i]), 
-                 zli[i], zri[i], epsabs = epsabs, epsrel = epsrel, limit = int(1e4))
-            
-            if J_err[i]/J[i] > 0.001:
-                print("Error is bigger than 0.1% for i = " + str(i) 
-                    + ", J = " + str(J[i]) +
-                    ", relative error on J = " + str(J_err[i]/J[i]))
-                x0 = np.linspace(zli[i] + eps,zri[i] - eps,2**n_pow+1)
-                y0 = (lambda z : 1/np.pi*self.Jfunc(z,zri[i]))(x0)
-                J[i] = romb(y0,x0[1]-x0[0])
-                print("Using romb to compute the integral instead of quad : J = " + str(J[i]))
-                
-            if np.isnan(J[i]):
-                print("J is nan for i = " + str(i) )
-                x0 = np.linspace(zli[i] + eps,zri[i] - eps,2**n_pow+1)
-                y0 = (lambda z : 1/np.pi*self.Jfunc(z,zri[i]))(x0)
-                J[i] = romb(y0,x0[1]-x0[0])
-                print("Using romb to compute the integral instead of quad : J = " + str(J[i]))
-                
-            Ts[i], Ts_err[i] = quad(lambda z : 2*self.Tsfunc(z,zri[i]) / 
-                (self.ring.ac*c), zli[i], zri[i],
-                epsabs = epsabs, epsrel = epsrel, limit = int(1e4))
-            
-            if Ts_err[i]/Ts[i] > 0.001:
-                print("Error is bigger than 0.1% for i = " + str(i) 
-                    + ", Ts = " + str(Ts[i]) +
-                    ", relative error on Ts = " + str(Ts_err[i]/Ts[i]))
-                x0 = np.linspace(zli[i] + eps,zri[i] - eps,2**n_pow+1)
-                y0 = (lambda z : 2*self.Tsfunc(z,zri[i])/(self.ring.ac*c))(x0)
-                Ts[i] = romb(y0,x0[1]-x0[0])
-                print("Using romb to compute the integral instead of quad : Ts = " + str(Ts[i]))
-                
-            if np.isnan(Ts[i]):
-                print("Ts is nan for i = " + str(i) )
-                x0 = np.linspace(zli[i] + eps,zri[i] - eps,2**n_pow+1)
-                y0 = (lambda z : 2*self.Tsfunc(z,zri[i])/(self.ring.ac*c))(x0)
-                Ts[i] = romb(y0,x0[1]-x0[0])
-                print("Using romb to compute the integral instead of quad : Ts = " + str(Ts[i]))
-                
-        Omegas = 2*np.pi/Ts
-        zri[0] = 0 # approximation to dz*zri0_coef
-        # step (iv)
-    
-        z_sol = np.zeros((nz,ntime))
-        delta_sol = np.zeros((nz,ntime))
-        phi_sol = np.zeros((nz,ntime))
-            
-        for i in range(nz):
-            y0 = (zri[i], 0)
-            tspan = (0, tmax)
-            sol = solve_ivp(self.EOMsystem, tspan, y0, rtol = rtol, atol = atol)
-            
-            time_base = np.linspace(tspan[0],tspan[1],ntime)
-            
-            fz = interp1d(sol.t,sol.y[0],kind='cubic') # cubic decreases K std but increases a bit K mean 
-            fdelta = interp1d(sol.t,sol.y[1],kind='cubic')
-        
-            z_sol[i,:] = fz(time_base)
-            delta_sol[i,:] = fdelta(time_base)
-            phi_sol[i,:] = Omegas[i]*time_base;
-            
-        # Plot the surface
-        
-        if plot:
-            fig = plt.figure()
-            ax = Axes3D(fig)
-            surf = ax.plot_surface(J, phi_sol.T, z_sol.T, rcount = 75,
-                                   ccount = 75, cmap =plt.cm.viridis,
-                                   linewidth=0, antialiased=True)
-            fig.colorbar(surf, shrink=0.5, aspect=5)
-            ax.set_xlabel('J [m]')
-            ax.set_ylabel('$\phi$ [rad]')
-            ax.set_zlabel('z [m]')
-        
-        # Check the results by computing the Jacobian matrix
-        
-        dz_dp = np.zeros((nz,ntime))
-        dz_dJ = np.zeros((nz,ntime))
-        dd_dp = np.zeros((nz,ntime))
-        dd_dJ = np.zeros((nz,ntime))
-        K = np.zeros((nz,ntime))
-        
-        for i in range(1,nz-1):
-            for j in range(1,ntime-1):
-                dz_dp[i,j] = (z_sol[i,j+1] - z_sol[i,j-1])/(phi_sol[i,j+1] - phi_sol[i,j-1])
-                dz_dJ[i,j] = (z_sol[i+1,j] - z_sol[i-1,j])/(J[i+1] - J[i-1])
-                if (J[i+1] - J[i-1]) == 0 :
-                    print(i)
-                dd_dp[i,j] = (delta_sol[i,j+1] - delta_sol[i,j-1])/(phi_sol[i,j+1] - phi_sol[i,j-1])
-                dd_dJ[i,j] = (delta_sol[i+1,j] - delta_sol[i-1,j])/(J[i+1] - J[i-1])
-        
-            
-        K = np.abs(dz_dp*dd_dJ - dd_dp*dz_dJ)
-        K[0,:] = K[1,:]
-        K[-1,:] = K[-2,:]
-        K[:,0] = K[:,1]
-        K[:,-1] = K[:,-2]
-        
-        print('Numerical Jacobian K: mean = ' + str(np.mean(K)) + ', value should be around 1.')
-        print('Numerical Jacobian K: std = ' + str(np.std(K)) + ', value should be around 0.')
-        print('Numerical Jacobian K: max = ' + str(np.max(K)) + ', value should be around 1.')
-        print('If the values are far from their nominal values, decrease atol, rtol, epsrel and epsabs.')
-        
-        # Plot the numerical Jacobian
-        
-        if plot:
-            fig = plt.figure()
-            ax = fig.gca(projection='3d')
-            surf = ax.plot_surface(J, phi_sol.T, K.T, rcount = 75, ccount = 75,
-                                   cmap =plt.cm.viridis,
-                                   linewidth=0, antialiased=True)
-            fig.colorbar(surf, shrink=0.5, aspect=5)
-            ax.set_xlabel('J [m]')
-            ax.set_ylabel('$\phi$ [rad]')
-            ax.set_zlabel('K')
-            
-        self.J = J
-        self.phi = phi_sol.T
-        self.z = z_sol.T
-        self.delta = delta_sol.T
-        self.Omegas = Omegas
-        self.Ts = Ts
-        
-        JJ, bla = np.meshgrid(J,J)
-        g = (JJ, phi_sol.T, z_sol.T)
-        J_p, phi_p, z_p = np.vstack(map(np.ravel, g))
-        
-        self.J_p = J_p
-        self.phi_p = phi_p
-        self.z_p = z_p
-        
-        # Provide interpolation functions for methods
-        self.func_J = interp1d(z_sol[:,0],J, bounds_error=True)
-        self.func_omegas = interp1d(J,Omegas, bounds_error=True)
-        
-        # Provide polygon to check if points are inside the map
-        p1 = [[J[i],phi_sol[i,:].max()] for i in range(J.size)]
-        p2 = [[J[i],phi_sol[i,:].min()] for i in range(J.size-1,0,-1)]
-        poly = p1+p2
-        self.path = mpltPath.Path(poly)
-        
-    def dphi0(self, z, delta):
-        """Partial derivative of the distribution function phi by z"""
-        dphi0 = -np.exp(-delta**2/(2*self.ring.sigma_delta**2)) / (np.sqrt(2*np.pi)*self.ring.sigma_delta) * self.rho(z)*self.du_dz(z)
-        return dphi0
-            
-    def zeta(self, J, phi):
-        """Compute zeta cannonical transformation z = zeta(J,phi) using 2d interpolation
-        interp2d and bisplrep does not seem to work for this case
-        griddata is working fine if method='linear'
-        griddata return nan if the point is outside the grid
-        """
-        z = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J, phi), method='linear')
-        if np.isnan(np.sum(z)) == True:
-            inside = self.path.contains_points(np.array([J, phi]).T)
-            if (np.sum(inside != True) >= 1):
-                print("Some points are outside the grid of the cannonical transformation zeta !")
-                print(str(np.sum(inside != True)) + " points are outside the grid out of " + str(J.size) + " points.")
-                plot_points_outside = False
-                if(plot_points_outside):
-                    fig = plt.figure()
-                    ax = fig.add_subplot(111)
-                    patch = patches.PathPatch(self.path, facecolor='orange', lw=2)
-                    ax.add_patch(patch)
-                    ax.set_xlim(self.J_p.min()*0.5,self.J_p.max()*1.5)
-                    ax.set_ylim(self.phi_p.min()*0.5,self.phi_p.max()*1.5)
-                    ax.scatter(J[np.nonzero(np.invert(inside))[0]],phi[np.nonzero(np.invert(inside))[0]])
-                    plt.show()
-                    vals = np.nonzero(np.invert(inside))[0]
-                    print("J min = " + str(J[vals].min()))
-                    print("J max = " + str(J[vals].max()))
-                    print("Phi min = " + str(phi[vals].min()))
-                    print("Phi max = " + str(phi[vals].max()))
-            else:
-                print("Some values interpolated from the zeta cannonical transformation are nan !")
-                print("This may be due to a scipy version < 1.3")
-            print("The nan values are remplaced by nearest values in the grid")
-            arr = np.isnan(z)
-            z[arr] = griddata(np.array([self.J_p,self.phi_p]).T, self.z_p, (J[arr], phi[arr]), method='nearest')
-        if np.isnan(np.sum(z)) == True:
-            print("Correction by nearest values in the grid failed, program is now stopping")
-            raise ValueError("zeta cannonical transformation are nan")
-        return z
-            
-    def H_coef(self, J_val, m, p, l, n_pow = 14):
-        """Compute H_{m,p} (J_val) using zeta cannonical transformation"""
-        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
-        omegap = self.ring.omega0*(p*self.ring.h + l)
-        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
-        res = np.zeros(J_val.shape, dtype=complex)
-        phi0_array = np.tile(phi0, J_val.size)
-        J_array = np.array([])
-        for i in range(J_val.size):
-            J_val_array = np.tile(J_val[i], phi0.size)
-            J_array = np.concatenate((J_array, J_val_array))
-        zeta_values = self.zeta(J_array, phi0_array)
-        y0 = np.exp(1j*m*phi0_array + 1j*omegap*zeta_values/c)
-        for i in range(J_val.size):
-            res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
-        return res
-    
-    def H_coef_star(self, J_val, m, p, l, n_pow = 14):
-        """Compute H_{m,p}^* (J_val) using zeta cannonical transformation"""
-        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
-        omegap = self.ring.omega0*(p*self.ring.h + l)
-        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
-        res = np.zeros(J_val.shape, dtype=complex)
-        phi0_array = np.tile(phi0, J_val.size)
-        J_array = np.array([])
-        for i in range(J_val.size):
-            J_val_array = np.tile(J_val[i], phi0.size)
-            J_array = np.concatenate((J_array, J_val_array))
-        zeta_values = self.zeta(J_array, phi0_array)
-        y0 = np.exp( - 1j*m*phi0_array - 1j*omegap*zeta_values/c)
-        for i in range(J_val.size):
-            res[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
-        return res
-    
-    def HH_coef(self, J_val, m, p, pp, l, n_pow = 10):
-        """Compute H_{m,p} (J_val) * H_{m,pp}^* (J_val) using zeta cannonical transformation"""
-        """quad does not work here, quadrature or romberg are possible if error estimate is needed but much slower"""
-        omegapH = self.ring.omega0*(p*self.ring.h + l)
-        omegapH_star = self.ring.omega0*(pp*self.ring.h + l)
-        phi0 = np.linspace(0,2*np.pi,2**n_pow+1)
-        H = np.zeros(J_val.shape, dtype=complex)
-        H_star = np.zeros(J_val.shape, dtype=complex)
-        phi0_array = np.tile(phi0, J_val.size)
-        J_array = np.array([])
-        for i in range(J_val.size):
-            J_val_array = np.tile(J_val[i], phi0.size)
-            J_array = np.concatenate((J_array, J_val_array))
-        zeta_values = self.zeta(J_array, phi0_array)
-        y0 = np.exp( 1j*m*phi0_array + 1j*omegapH*zeta_values/c)
-        y1 = np.exp( - 1j*m*phi0_array - 1j*omegapH_star*zeta_values/c)
-        for i in range(J_val.size):
-            H[i] = romb(y0[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
-            H_star[i] = romb(y1[i*phi0.size:(i+1)*phi0.size],phi0[1]-phi0[0])
-        return H*H_star
-    
-    def J_from_z(self, z):
-        """Return the action J corresponding to a given amplitude z_r,
-        corresponds to the inversion of z_r = zeta(J,0) : J = zeta^-1(z)
-        """
-        return self.func_J(z)
-    
-    def Omegas_from_z(self,z):
-        """Return the synchrotron angular frequency omega_s for a given amplitude z_r"""
-        return self.func_omegas(self.J_from_z(z))
-    
-    def G(self,m,p,pp,l,omega,delta_val = 0, n_pow = 8, Gmin = 1e-8, Gmax = 4.5e-2):
-        """Compute the G integral"""
-        g_func = lambda z : self.HH_coef(self.J_from_z(z), m, p, pp, l)*self.dphi0(z,delta_val)/(omega - m*self.Omegas_from_z(z))
-        z0 = np.linspace(Gmin,Gmax,2**n_pow+1)
-        y0 = g_func(z0)
-        G_val = romb(y0,z0[1]-z0[0])
-        #if( np.abs(y0[0]/G_val) > 0.001 or np.abs(y0[-1]/G_val) > 0.001 ):
-            #print("Integration boundaries for G value might be wrong.")
-            #plt.plot(z0,y0)
-        return G_val
-    
-    def mpi_init(self):
-        """Switch on mpi"""
-        
-        self.mpi = True
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-                
-        if(rank == 0):
-            pass
-        else:
-            while(True):
-                order = comm.bcast(None,0)
-
-                if(order == "init"):
-                    VLASOV = comm.bcast(None,0)
-                
-                if(order == "B_matrix"):
-                    Bsize = comm.bcast(None,0)
-                    if Bsize**2 != comm.size - 1:
-                        #raise ValueError("The number of processor must be Bsize**2 + 1, which is :",Bsize**2 + 1)
-                        #sys.exit()
-                        pass
-                    omega = comm.bcast(None,0)
-                    mmax = comm.bcast(None,0)
-                    l_solve = comm.bcast(None,0)
-                    
-                    Ind_table = np.zeros((2,Bsize)) # m, p
-                    for i in range(VLASOV.n_cavity):
-                        cav = VLASOV.cavity_list[i]
-                        Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
-                        Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
-                        Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
-                        Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
-                    
-                    matrix_i = np.zeros((Bsize,Bsize))
-                    matrix_j = np.zeros((Bsize,Bsize))
-                    for i in range(Bsize):
-                        for j in range(Bsize):
-                            matrix_i[i,j] = i
-                            matrix_j[i,j] = j
-                    
-                    i = int(matrix_i.flatten()[rank-1])
-                    j = int(matrix_j.flatten()[rank-1])
-                    
-                    B = np.zeros((Bsize,Bsize), dtype=complex)
-                    
-                    omegap = VLASOV.ring.omega0*(Ind_table[1,j]*VLASOV.ring.h + l_solve)
-                    Z = np.zeros((1,),dtype=complex)
-                    for k in range(VLASOV.n_cavity):
-                        cav = VLASOV.cavity_list[k]
-                        Z += cav.Z(omegap + omega)
-                    if i == j:
-                        B[i,j] += 1
-                    B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*VLASOV.I0/VLASOV.ring.E0/VLASOV.ring.T0*c*Z/omegap*VLASOV.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
-                    
-                    comm.Reduce([B, MPI.COMPLEX], None, op=MPI.SUM, root=0)
-
-                if(order == "stop"):
-                    sys.exit()               
-                    
-    def mpi_exit(self):
-        """Switch off mpi"""
-        
-        self.mpi = False
-        comm = MPI.COMM_WORLD
-        rank = comm.Get_rank()
-        
-        if(rank == 0):
-            comm.bcast("stop",0)
-    
-    def detB(self,omega,mmax,l_solve):
-        """Return the determinant of the matrix B"""
-        Bsize = 2*self.n_cavity*(2*mmax + 1)
-
-        if(self.mpi):
-            comm = MPI.COMM_WORLD
-            comm.bcast("B_matrix",0)
-            comm.bcast(Bsize,0)
-            comm.bcast(omega,0)
-            comm.bcast(mmax,0)
-            comm.bcast(l_solve,0)
-            B = np.zeros((Bsize,Bsize), dtype=complex)
-            comm.Reduce([np.zeros((Bsize,Bsize), dtype=complex), MPI.COMPLEX], [B, MPI.COMPLEX],op=MPI.SUM, root=0)
-        else:
-            Ind_table = np.zeros((2,Bsize)) # m, p
-            for i in range(self.n_cavity):
-                cav = self.cavity_list[i]
-                Ind_table[0,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = np.arange(-mmax,mmax+1)
-                Ind_table[0,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = np.arange(-mmax,mmax+1)
-                Ind_table[1,(2*mmax + 1)*2*i:(2*mmax + 1)*(2*i + 1)] = - cav.m
-                Ind_table[1,(2*mmax + 1)*(2*i + 1):(2*mmax + 1)*(2*i+2)] = cav.m
-            B = np.eye(Bsize, dtype=complex)
-            for i in range(Bsize):
-                for j in range(Bsize):
-                    omegap = self.ring.omega0*(Ind_table[1,j]*self.ring.h + l_solve)
-                    Z = np.zeros((1,),dtype=complex)
-                    for k in range(self.n_cavity):
-                        cav = self.cavity_list[k]
-                        Z += cav.Z(omegap + omega)
-                    B[i,j] +=  2*np.pi*1j*Ind_table[1,i]*self.I0/self.ring.E0/self.ring.T0*c*Z/omegap*self.G(Ind_table[0,i],Ind_table[1,i],Ind_table[1,j],l_solve,omega)
-        return np.linalg.det(B)
-    
-    def solveB(self,omega0,mmax,l_solve, maxfev = 200):
-        """Solve equation detB = 0
-        
-        Parameters
-        ----------
-        omega0 : initial guess with omega0[0] the real part of the solution and omega0[1] the imaginary part
-        mmax : maximum absolute value of m, see Eq. 20 of [1]
-        l_solve : instability coupled-bunch mode number
-        maxfev : the maximum number of calls to the function
-        
-        Returns
-        -------
-        omega : solution
-        infodict : dictionary of scipy.optimize.fsolve ouput
-        ier : interger flag, set to 1 if a solution was found
-        mesg : if no solution is found, mesg details the cause of failure
-        """
-        
-        def func(omega):
-            res = self.detB(omega[0] + 1j*omega[1],mmax,l_solve)
-            return real(res), imag(res)
-
-        if not self.mpi:
-            omega, infodict, ier, mesg = fsolve(func,omega0, maxfev = maxfev, full_output = True)
-        elif self.mpi and MPI.COMM_WORLD.rank == 0:
-            comm = MPI.COMM_WORLD
-            comm.bcast("init",0)
-            comm.bcast(self, 0)
-            omega, infodict, ier, mesg = fsolve(func,omega0, maxfev = maxfev, full_output = True)
-        
-        if ier != 1:
-            print("The algorithm has not converged: " + str(ier))
-            print(mesg)
-            
-        return omega, infodict, ier, mesg
-
-        
-        
-- 
GitLab