From f3296289fb89342db4aad904dc27e4e31a42e3b4 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Fri, 25 Mar 2022 15:43:41 +0100
Subject: [PATCH] Enforce fundamental theorem of beam loading for wake
 functions

Change wake function in CircularResistiveWall to enforce fundamental theorem of beam loading.
Rework Resonator class to set manually time and frequency arrays and enforce fundamental theorem of beam loading.
Change to WakeField.impedance_components and WakeField.wake_components to allow methods begining by W or Z in WakeField subclasses.
---
 collective_effects/resistive_wall.py | 119 ++++++++++++++++-----------
 collective_effects/resonator.py      |  92 ++++++++++-----------
 collective_effects/wakefield.py      |   6 +-
 3 files changed, 119 insertions(+), 98 deletions(-)

diff --git a/collective_effects/resistive_wall.py b/collective_effects/resistive_wall.py
index 1bb3854..c0ae255 100644
--- a/collective_effects/resistive_wall.py
+++ b/collective_effects/resistive_wall.py
@@ -62,6 +62,10 @@ class CircularResistiveWall(WakeField):
     exact : bool, optional
         If False, approxmiated formulas are used for the wake function 
         computations.
+    atol : float, optional
+        Absolute tolerance used to enforce fundamental theorem of beam loading
+        for the exact expression of the longitudinal wake function.
+        Default is 1e-20.
     
     References
     ----------
@@ -74,12 +78,15 @@ class CircularResistiveWall(WakeField):
 
     """
 
-    def __init__(self, time, frequency, length, rho, radius, exact=False):
+    def __init__(self, time, frequency, length, rho, radius, exact=False, 
+                 atol=1e-20):
         super().__init__()
         
         self.length = length
         self.rho = rho
         self.radius = radius
+        self.Z0 = mu_0*c
+        self.t0 = (2*self.rho*self.radius**2 / self.Z0)**(1/3) / c
         
         omega = 2*np.pi*frequency
         Z1 = length*(1 + np.sign(frequency)*1j)*rho/(
@@ -87,7 +94,7 @@ class CircularResistiveWall(WakeField):
         Z2 = c/omega*length*(1 + np.sign(frequency)*1j)*rho/(
                 np.pi*radius**3*skin_depth(frequency,rho))
         
-        Wl = self.LongitudinalWakeFunction(time, exact)
+        Wl = self.LongitudinalWakeFunction(time, exact, atol)
         Wt = self.TransverseWakeFunction(time, exact)
         
         Zlong = Impedance(variable = frequency, function = Z1, impedance_type='long')
@@ -104,12 +111,15 @@ class CircularResistiveWall(WakeField):
         super().append_to_model(Wxdip)
         super().append_to_model(Wydip)
         
-    def LongitudinalWakeFunction(self, time, exact=False):
+    def LongitudinalWakeFunction(self, time, exact=False, atol=1e-20):
         """
         Compute the longitudinal wake function of a circular resistive wall 
         using Eq. (22), or approxmiated expression Eq. (24), of [1]. The 
         approxmiated expression is valid if the time is large compared to the 
         characteristic time t0.
+        
+        If some time value is smaller than atol, then the fundamental theorem 
+        of beam loading is applied: Wl(0) = Wl(0+)/2.
 
         Parameters
         ----------
@@ -117,6 +127,10 @@ class CircularResistiveWall(WakeField):
             Time points where the wake function is evaluated in [s].
         exact : bool, optional
             If True, the exact expression is used. The default is False.
+        atol : float, optional
+            Absolute tolerance used to enforce fundamental theorem of beam loading
+            for the exact expression of the longitudinal wake function.
+            Default is 1e-20.
 
         Returns
         -------
@@ -130,30 +144,17 @@ class CircularResistiveWall(WakeField):
         and Methods in Physics Research Section A: Accelerators, Spectrometers, 
         Detectors and Associated Equipment 806 (2016): 221-230.
         """
-        
-        Z0 = mu_0*c
-        
+        wl = np.zeros_like(time)
+        idx1 = time < 0
+        wl[idx1] = 0
         if exact==True:
-            self.t0 = (2*self.rho*self.radius**2 / Z0)**(1/3) / c
-            factor = 4*Z0*c/(np.pi * self.radius**2) * self.length * -1
-            wl = np.zeros_like(time)
-            
-            for i, t in enumerate(time):
-                val, err = quad(lambda z:self.function(t, z), 0, np.inf)
-                if t < 0:
-                    wl[i] = 0
-                else:
-                    wl[i] = factor * ( np.exp(-t/self.t0) / 3 * 
-                                      np.cos( np.sqrt(3) * t / self.t0 )  
-                                      - np.sqrt(2) / np.pi * val )
+            idx2 = time > 20 * self.t0
+            idx3 = np.logical_not(np.logical_or(idx1,idx2))
+            wl[idx3] = self.__LongWakeExact(time[idx3], atol)
         else:
-            wl = (1/(4*np.pi * self.radius) 
-                  * np.sqrt(Z0 * self.rho / (c * np.pi) ) 
-                  / time**(3/2) ) * self.length
-            ind = np.isnan(wl)
-            wl[ind] = 0
-    
-        return -1*wl
+            idx2 = np.logical_not(idx1)
+        wl[idx2] = self.__LongWakeApprox(time[idx2])
+        return wl
     
     def TransverseWakeFunction(self, time, exact=False):
         """
@@ -183,35 +184,59 @@ class CircularResistiveWall(WakeField):
         and Methods in Physics Research Section A: Accelerators, Spectrometers, 
         Detectors and Associated Equipment 806 (2016): 221-230.
         """
-        
-        Z0 = mu_0*c
-        
+        wt = np.zeros_like(time)
+        idx1 = time < 0
+        wt[idx1] = 0
         if exact==True:
-            self.t0 = (2*self.rho*self.radius**2 / Z0)**(1/3) / c
-            factor = -1 * (8 * Z0 * c**2 * self.t0) / (np.pi * self.radius**4) * self.length
-            wt = np.zeros_like(time)
-            
-            for i, t in enumerate(time):
-                val, err = quad(lambda z:self.function2(t, z), 0, np.inf)
-                if t < 0:
-                    wt[i] = 0
-                else:
-                    wt[i] = factor * ( 1 / 12 * (-1 * np.exp(-t/self.t0) * 
+            idx2 = time > 20 * self.t0
+            idx3 = np.logical_not(np.logical_or(idx1,idx2))
+            wt[idx3] = self.__TransWakeExact(time[idx3])
+        else:
+            idx2 = np.logical_not(idx1)
+        wt[idx2] = self.__TransWakeApprox(time[idx2])
+        return wt
+    
+    def __LongWakeExact(self, time, atol):
+        wl = np.zeros_like(time)
+        factor = 4*self.Z0*c/(np.pi * self.radius**2) * self.length
+        for i, t in enumerate(time):
+            val, err = quad(lambda z:self.__function(t, z), 0, np.inf)
+            wl[i] = factor * ( np.exp(-t/self.t0) / 3 * 
+                              np.cos( np.sqrt(3) * t / self.t0 )  
+                              - np.sqrt(2) / np.pi * val )
+            if np.isclose(0, t, atol=atol):
+                wl[i] = wl[i]/2
+        return wl
+    
+    def __TransWakeExact(self, time):
+        wt = np.zeros_like(time)
+        factor = ((8 * self.Z0 * c**2 * self.t0) / (np.pi * self.radius**4) * 
+                  self.length)
+        for i, t in enumerate(time):
+            val, err = quad(lambda z:self.__function2(t, z), 0, np.inf)
+            wt[i] = factor * ( 1 / 12 * (-1 * np.exp(-t/self.t0) * 
                                       np.cos( np.sqrt(3) * t / self.t0 ) + 
                                       np.sqrt(3) * np.exp(-t/self.t0) * 
                                       np.sin( np.sqrt(3) * t / self.t0 ) ) -
                                       np.sqrt(2) / np.pi * val )
-        else:
-            wt = (1 / (np.pi * self.radius**3) * np.sqrt(Z0 * c * self.rho / np.pi) 
-                  / time**(1/2) * self.length * -1)
-            ind = np.isnan(wt)
-            wt[ind] = 0
-        return -1*wt
-        
-    def function(self, t, x):
+        return wt
+    
+    def __LongWakeApprox(self, t):
+        wl = - 1 * ( 1 / (4*np.pi * self.radius) * 
+                    np.sqrt(self.Z0 * self.rho / (c * np.pi) ) /
+                    t ** (3/2) ) * self.length
+        return wl
+    
+    def __TransWakeApprox(self, t):
+        wt = (1 / (np.pi * self.radius**3) *
+              np.sqrt(self.Z0 * c * self.rho / np.pi)
+              / t ** (1/2) * self.length)
+        return wt
+    
+    def __function(self, t, x):
         return ( (x**2 * np.exp(-1* (x**2) * t / self.t0) ) / (x**6 + 8) )
     
-    def function2(self, t, x):
+    def __function2(self, t, x):
         return ( (-1 * np.exp(-1* (x**2) * t / self.t0) ) / (x**6 + 8) )
     
 class Coating(WakeField):
diff --git a/collective_effects/resonator.py b/collective_effects/resonator.py
index f11948e..41168b3 100644
--- a/collective_effects/resonator.py
+++ b/collective_effects/resonator.py
@@ -12,27 +12,29 @@ from mbtrack2.collective_effects.wakefield import (WakeField, Impedance,
                                                    WakeFunction)
 
 class Resonator(WakeField):
-    def __init__(self, Rs, fr, Q, plane, n_wake=1e6, n_imp=1e6, imp_freq_lim=100e9):
+    def __init__(self, time, frequency, Rs, fr, Q, plane, atol=1e-20):
         """
         Resonator model WakeField element which computes the impedance and the 
         wake function in both longitudinal and transverse case.
 
         Parameters
         ----------
+        time : array of float
+            Time points where the wake function will be evaluated in [s].
+        frequency : array of float
+            Frequency points where the impedance will be evaluated in [Hz].
         Rs : float
             Shunt impedance in [ohm].
         fr : float
             Resonance frequency in [Hz].
         Q : float
             Quality factor.
-        plane : str
+        plane : str or list
             Plane on which the resonator is used: "long", "x" or "y".
-        n_wake : int or float, optional
-            Number of points used in the wake function.
-        n_imp : int or float, optional
-            Number of points used in the impedance.
-        imp_freq_lim : float, optional
-            Maximum frequency used in the impedance.
+        atol : float, optional
+            Absolute tolerance used to enforce fundamental theorem of beam 
+            loading for the exact expression of the longitudinal wake function.
+            Default is 1e-20.
             
         References
         ----------
@@ -45,62 +47,54 @@ class Resonator(WakeField):
         self.fr = fr
         self.wr = 2 * np.pi * self.fr
         self.Q = Q
-        self.n_wake = int(n_wake)
-        self.n_imp = int(n_imp)
-        self.imp_freq_lim = imp_freq_lim
-        self.plane = plane
-
-        self.timestop = round(np.log(1000)/self.wr*2*self.Q, 12)
-        
+        if isinstance(plane, str):
+            self.plane = [plane]
+        elif isinstance(plane, list):
+            self.plane = plane
+            
         if self.Q >= 0.5:
             self.Q_p = np.sqrt(self.Q**2 - 0.25)
         else:
             self.Q_p = np.sqrt(0.25 - self.Q**2)
-            
         self.wr_p = (self.wr*self.Q_p)/self.Q
         
-        if self.plane == "long":
-            
-            freq = np.linspace(start=1, stop=self.imp_freq_lim, num=self.n_imp)
-            imp = Impedance(variable=freq, 
-                            function=self.long_impedance(freq),
-                            impedance_type="long")
-            super().append_to_model(imp)
-            
-            time = np.linspace(start=0, stop=self.timestop, num=self.n_wake)
-            wake = WakeFunction(variable=time,
-                                function=self.long_wake_function(time),
-                                wake_type="long")
-            super().append_to_model(wake)
-            
-        elif self.plane == "x" or self.plane == "y" :
-            
-            freq = np.linspace(start=1, stop=self.imp_freq_lim, num=self.n_imp)
-            imp = Impedance(variable=freq, 
-                            function=self.transverse_impedance(freq),
-                            impedance_type=self.plane + "dip")
-            super().append_to_model(imp)
-            
-            time = np.linspace(start=0, stop=self.timestop, num=self.n_wake)
-            wake = WakeFunction(variable=time,
-                                function=self.transverse_wake_function(time),
-                                wake_type=self.plane + "dip")
-            super().append_to_model(wake)
-        else:
-            raise ValueError("Plane must be: long, x or y")
+        for dim in self.plane:
+            if dim == "long":
+                Zlong = Impedance(variable=frequency, 
+                                function=self.long_impedance(frequency),
+                                impedance_type="long")
+                super().append_to_model(Zlong)
+                Wlong = WakeFunction(variable=time,
+                                    function=self.long_wake_function(time, atol),
+                                    wake_type="long")
+                super().append_to_model(Wlong)
+                
+            elif dim == "x" or dim == "y":
+                Zdip = Impedance(variable=frequency, 
+                                function=self.transverse_impedance(frequency),
+                                impedance_type=dim + "dip")
+                super().append_to_model(Zdip)
+                Wdip = WakeFunction(variable=time,
+                                    function=self.transverse_wake_function(time),
+                                    wake_type=dim + "dip")
+                super().append_to_model(Wdip)
+            else:
+                raise ValueError("Plane must be: long, x or y")
         
-    def long_wake_function(self, t):
+    def long_wake_function(self, t, atol):
         if self.Q >= 0.5:
-            return ( (self.wr * self.Rs / self.Q) * 
+            wl = ( (self.wr * self.Rs / self.Q) * 
                     np.exp(-1* self.wr * t / (2 * self.Q) ) *
                      (np.cos(self.wr_p * t) - 
                       np.sin(self.wr_p * t) / (2 * self.Q_p) ) )
-                        
         elif self.Q < 0.5:
-            return ( (self.wr * self.Rs / self.Q) * 
+            wl = ( (self.wr * self.Rs / self.Q) * 
                     np.exp(-1* self.wr * t / (2 * self.Q) ) *
                      (np.cosh(self.wr_p * t) - 
                       np.sinh(self.wr_p * t) / (2 * self.Q_p) ) )
+        if np.any(np.abs(t) < atol):
+            wl[np.abs(t) < atol] = wl[np.abs(t) < atol]/2
+        return wl
                             
     def long_impedance(self, f):
         return self.Rs / (1 + 1j * self.Q * (f/self.fr - self.fr/f))
diff --git a/collective_effects/wakefield.py b/collective_effects/wakefield.py
index 36b37f2..3af5ea0 100644
--- a/collective_effects/wakefield.py
+++ b/collective_effects/wakefield.py
@@ -691,14 +691,16 @@ class WakeField:
         """
         Return an array of the impedance component names for the element.
         """
-        return np.array([comp for comp in dir(self) if re.match(r'[Z]', comp)])
+        valid = ["Zlong", "Zxdip", "Zydip", "Zxquad", "Zyquad"]
+        return np.array([comp for comp in dir(self) if comp in valid])
     
     @property
     def wake_components(self):
         """
         Return an array of the wake function component names for the element.
         """
-        return np.array([comp for comp in dir(self) if re.match(r'[W]', comp)])
+        valid = ["Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad"]
+        return np.array([comp for comp in dir(self) if comp in valid])
     
     @staticmethod
     def add_wakefields(wake1, beta1, wake2, beta2):
-- 
GitLab