From d8075bf9cd80c913e45fff5e7840c0184762667f Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 25 Jan 2023 14:44:43 +0100
Subject: [PATCH] Add TransverseMapSector element

Update the Optics class to be able to compute the phase advance at any location on the lattice.
Add the TransverseMapSector element to be able to split the lattice in several parts.
Add the transverse_map_sector_generator convenience function to split a lattice in several TransverseMapSector elements.
---
 mbtrack2/tracking/__init__.py |   4 +-
 mbtrack2/tracking/element.py  | 158 ++++++++++++++++++++++++++++++++++
 mbtrack2/utilities/optics.py  |  41 +++++++--
 3 files changed, 197 insertions(+), 6 deletions(-)

diff --git a/mbtrack2/tracking/__init__.py b/mbtrack2/tracking/__init__.py
index 9ce56b0..38a0f8d 100644
--- a/mbtrack2/tracking/__init__.py
+++ b/mbtrack2/tracking/__init__.py
@@ -12,7 +12,9 @@ from mbtrack2.tracking.element import (Element,
                                        LongitudinalMap, 
                                        TransverseMap, 
                                        SynchrotronRadiation,
-                                       SkewQuadrupole)
+                                       SkewQuadrupole,
+                                       TransverseMapSector,
+                                       transverse_map_sector_generator)
 from mbtrack2.tracking.aperture import (CircularAperture, 
                                         ElipticalAperture,
                                         RectangularAperture, 
diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index a9f2efd..a31c66e 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -243,3 +243,161 @@ class SkewQuadrupole:
         bunch['xp'] = bunch['xp'] - self.strength * bunch['y']
         bunch['yp'] = bunch['yp'] - self.strength * bunch['x']
 
+class TransverseMapSector(Element):
+    """
+    Transverse map for a sector of the synchrotron, from an initial 
+    position s0 to a final position s1.
+
+    Parameters
+    ----------
+    ring : Synchrotron object
+        Ring parameters.
+    alpha0 : array of shape (2,)
+        Alpha Twiss function at the initial location of the sector.
+    beta0 : array of shape (2,)
+        Beta Twiss function at the initial location of the sector.
+    dispersion0 : array of shape (4,)
+        Dispersion function at the initial location of the sector.
+    alpha1 : array of shape (2,)
+        Alpha Twiss function at the final location of the sector.
+    beta1 : array of shape (2,)
+        Beta Twiss function at the final location of the sector.
+    dispersion1 : array of shape (4,)
+        Dispersion function at the final location of the sector.
+    phase_diff : array of shape (2,)
+        Phase difference between the initial and final location of the 
+        sector.
+    chro_diff : array of shape (2,)
+        Chromaticity difference between the initial and final location of 
+        the sector.
+    adts : array of shape (4,), optional
+        Amplitude-dependent tune shift of the sector, see Synchrotron class 
+        for details. The default is None.
+
+    """
+    def __init__(self, ring, alpha0, beta0, dispersion0, alpha1, beta1, 
+                 dispersion1, phase_diff, chro_diff, adts=None):
+        self.ring = ring
+        self.alpha0 = alpha0
+        self.beta0 = beta0
+        self.dispersion0 = dispersion0
+        self.alpha1 = alpha1
+        self.beta1 = beta1
+        self.dispersion1 = dispersion1  
+        self.tune_diff = phase_diff / (2*np.pi)
+        self.chro_diff = chro_diff
+        if adts is not None:
+            self.adts_poly = [np.poly1d(adts[0]),
+                              np.poly1d(adts[1]),
+                              np.poly1d(adts[2]), 
+                              np.poly1d(adts[3])]
+        else:
+            self.adts_poly = None
+    
+    @Element.parallel    
+    def track(self, bunch):
+        """
+        Tracking method for the element.
+        No bunch to bunch interaction, so written for Bunch objects and
+        @Element.parallel is used to handle Beam objects.
+        
+        Parameters
+        ----------
+        bunch : Bunch or Beam object
+        """
+
+        # Compute phase advance which depends on energy via chromaticity and ADTS
+        if self.adts_poly is None:
+            phase_advance_x = 2*np.pi * (self.tune_diff[0] + 
+                                         self.chro_diff[0]*bunch["delta"])
+            phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
+                                         self.chro_diff[1]*bunch["delta"])
+        else:
+            phase_advance_x = 2*np.pi * (self.tune_diff[0] + 
+                                         self.chro_diff[0]*bunch["delta"] + 
+                                         self.adts_poly[0](bunch['x']) + 
+                                         self.adts_poly[2](bunch['y']))
+            phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
+                                         self.chro_diff[1]*bunch["delta"] +
+                                         self.adts_poly[1](bunch['x']) + 
+                                         self.adts_poly[3](bunch['y']))
+        
+        # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
+        matrix = np.zeros((6,6,len(bunch)))
+        
+        # Horizontal
+        matrix[0,0,:] = np.sqrt(self.beta1[0]/self.beta0[0])*(np.cos(phase_advance_x) + self.alpha0[0]*np.sin(phase_advance_x))
+        matrix[0,1,:] = np.sqrt(self.beta0[0]*self.beta1[0])*np.sin(phase_advance_x)
+        matrix[0,2,:] = self.dispersion1[0] - matrix[0,0,:]*self.dispersion0[0] - matrix[0,1,:]*self.dispersion0[1]
+        matrix[1,0,:] = ((self.alpha0[0] - self.alpha1[0])*np.cos(phase_advance_x) - (1 + self.alpha0[0]*self.alpha1[0])*np.sin(phase_advance_x))/np.sqrt(self.beta0[0]*self.beta1[0])
+        matrix[1,1,:] = np.sqrt(self.beta0[0]/self.beta1[0])*(np.cos(phase_advance_x) - self.alpha1[0]*np.sin(phase_advance_x))
+        matrix[1,2,:] = self.dispersion1[1] - matrix[1,0,:]*self.dispersion0[0] - matrix[1,1,:]*self.dispersion0[1]
+        matrix[2,2,:] = 1
+        
+        # Vertical
+        matrix[3,3,:] = np.sqrt(self.beta1[1]/self.beta0[1])*(np.cos(phase_advance_y) + self.alpha0[1]*np.sin(phase_advance_y))
+        matrix[3,4,:] = np.sqrt(self.beta0[1]*self.beta1[1])*np.sin(phase_advance_y)
+        matrix[3,5,:] = self.dispersion1[2] - matrix[3,3,:]*self.dispersion0[2] - matrix[3,4,:]*self.dispersion0[3]
+        matrix[4,3,:] = ((self.alpha0[1] - self.alpha1[1])*np.cos(phase_advance_y) - (1 + self.alpha0[1]*self.alpha1[1])*np.sin(phase_advance_y))/np.sqrt(self.beta0[1]*self.beta1[1])
+        matrix[4,4,:] = np.sqrt(self.beta0[1]/self.beta1[1])*(np.cos(phase_advance_y) - self.alpha1[1]*np.sin(phase_advance_y))
+        matrix[4,5,:] = self.dispersion1[3] - matrix[4,3,:]*self.dispersion0[2] - matrix[4,4,:]*self.dispersion0[3]
+        matrix[5,5,:] = 1
+        
+        x = matrix[0,0,:]*bunch["x"] + matrix[0,1,:]*bunch["xp"] + matrix[0,2,:]*bunch["delta"]
+        xp = matrix[1,0,:]*bunch["x"] + matrix[1,1,:]*bunch["xp"] + matrix[1,2,:]*bunch["delta"]
+        y =  matrix[3,3,:]*bunch["y"] + matrix[3,4,:]*bunch["yp"] + matrix[3,5,:]*bunch["delta"]
+        yp = matrix[4,3,:]*bunch["y"] + matrix[4,4,:]*bunch["yp"] + matrix[4,5,:]*bunch["delta"]
+        
+        bunch["x"] = x
+        bunch["xp"] = xp
+        bunch["y"] = y
+        bunch["yp"] = yp
+        
+def transverse_map_sector_generator(ring, positions):
+    """
+    Convenience function which generate a list of TransverseMapSector elements
+    from a ring with an AT lattice.
+    
+    Tracking through all the sectors is equivalent to a full turn (and thus to 
+    the TransverseMap object).
+
+    Parameters
+    ----------
+    ring : Synchrotron object
+        Ring parameters, must .
+    positions : array
+        List of longitudinal positions in [m] to use as starting and end points
+        of the TransverseMapSector elements.
+
+    Returns
+    -------
+    sectors : list
+        List of TransverseMapSector elements.
+
+    """
+    
+    if ring.optics.use_local_values:
+        raise ValueError("The Synchrotron object must be loaded from an AT lattice")
+    
+    N_sec = len(positions)
+    sectors = []
+    chro_diff = ring.chro/N_sec
+    for i in range(N_sec):
+        alpha0 = ring.optics.alpha(positions[i])
+        beta0 = ring.optics.beta(positions[i])
+        dispersion0 = ring.optics.dispersion(positions[i])
+        mu0 = ring.optics.mu(positions[i])
+        if i != (N_sec - 1):
+            alpha1 = ring.optics.alpha(positions[i+1])
+            beta1 = ring.optics.beta(positions[i+1])
+            dispersion1 = ring.optics.dispersion(positions[i+1])
+            mu1 = ring.optics.mu(positions[i+1])
+        else:
+            alpha1 = ring.optics.alpha(positions[0])
+            beta1 = ring.optics.beta(positions[0])
+            dispersion1 = ring.optics.dispersion(positions[0])
+            mu1 = ring.optics.mu(ring.L)
+        phase_diff = mu1 - mu0
+        sectors.append(TransverseMapSector(ring, alpha0, beta0, dispersion0, 
+                     alpha1, beta1, dispersion1, phase_diff, chro_diff))
+    return sectors
diff --git a/mbtrack2/utilities/optics.py b/mbtrack2/utilities/optics.py
index fdbd009..6552918 100644
--- a/mbtrack2/utilities/optics.py
+++ b/mbtrack2/utilities/optics.py
@@ -123,6 +123,7 @@ class Optics:
         self.beta_array = np.tile(twiss.beta.T, self.periodicity)
         self.alpha_array = np.tile(twiss.alpha.T, self.periodicity)
         self.dispersion_array = np.tile(twiss.dispersion.T, self.periodicity)
+        self.mu_array = np.tile(twiss.mu.T, self.periodicity)
         
         self.position = np.append(self.position, self.lattice.circumference)
         self.beta_array = np.append(self.beta_array, self.beta_array[:,0:1],
@@ -131,12 +132,17 @@ class Optics:
                                      axis=1)
         self.dispersion_array = np.append(self.dispersion_array,
                                           self.dispersion_array[:,0:1], axis=1)
+        self.mu_array = np.append(self.mu_array,
+                                  self.mu_array[:,0:1], axis=1)
         
         self.gamma_array = (1 + self.alpha_array**2)/self.beta_array
         self.tune = tune * self.periodicity
         self.chro = chrom * self.periodicity
         self.ac = at.get_mcf(self.lattice)
         
+        self.mu_array[:,-1] = (np.floor(self.mu_array[:,-2]/(2*np.pi)) +
+                               self.tune)*2*np.pi
+        
         self.setup_interpolation()
         
         
@@ -162,6 +168,10 @@ class Optics:
                               kind='linear')
         self.disppY = interp1d(self.position, self.dispersion_array[3,:],
                                kind='linear')
+        self.muX = interp1d(self.position, self.mu_array[0,:],
+                              kind='linear')
+        self.muY = interp1d(self.position, self.mu_array[1,:],
+                              kind='linear')
     
     @property
     def local_beta(self):
@@ -303,13 +313,34 @@ class Optics:
                           self.dispY(position), self.disppY(position)]
             return np.array(dispersion)
         
+    def mu(self, position):
+        """
+        Return phase advances at specific locations given by position. 
+        If no lattice has been loaded, None is returned.
+
+        Parameters
+        ----------
+        position : array or float
+            Longitudinal position at which the phase advances are returned.
+
+        Returns
+        -------
+        mu : array
+            Phase advances.
+        """
+        if self.use_local_values:
+            return np.outer(np.array([0,0]), np.ones_like(position))
+        else:
+            mu = [self.muX(position), self.muY(position)]
+            return np.array(mu)
+        
     def plot(self, var, option, n_points=1000):
         """
         Plot optical variables.
     
         Parameters
         ----------
-        var : {"beta", "alpha", "gamma", "dispersion"}
+        var : {"beta", "alpha", "gamma", "dispersion", "mu"}
             Optical variable to be plotted.
         option : str
             If var = "beta", "alpha" and "gamma", option = {"x","y"} specifying
@@ -322,7 +353,7 @@ class Optics:
         """
     
         var_dict = {"beta":self.beta, "alpha":self.alpha, "gamma":self.gamma, 
-                    "dispersion":self.dispersion}
+                    "dispersion":self.dispersion, "mu":self.mu}
         
         if var == "dispersion":
             option_dict = {"x":0, "px":1, "y":2, "py":3}
@@ -332,15 +363,15 @@ class Optics:
             ylabel = label[option_dict[option]]
          
         
-        elif var=="beta" or var=="alpha" or var=="gamma":
+        elif var=="beta" or var=="alpha" or var=="gamma" or var=="mu":
             option_dict = {"x":0, "y":1}
             label_dict = {"beta":"$\\beta$", "alpha":"$\\alpha$", 
-                          "gamma":"$\\gamma$"}
+                          "gamma":"$\\gamma$", "mu":"$\\mu$"}
             
             if option == "x": label_sup = "$_{x}$"
             elif option == "y": label_sup = "$_{y}$"
             
-            unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)"}
+            unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)", "mu":""}
             
             ylabel = label_dict[var] + label_sup + unit[var]
   
-- 
GitLab