From 0166b60ac4118b80b76d1728075bb1e699204059 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Fri, 10 Feb 2023 16:15:03 +0100
Subject: [PATCH] [BugFix] ADTS

Fix ADTS implementation in TransverseMap and TransverseMapSector to use Jx/Jy instead of x/y.
Add Synchrotron.get_adts function which compute ADTS from AT lattice.
Change Synchrotron class docstring regarding ADTS.
---
 mbtrack2/tracking/element.py     | 30 ++++++++++++++++++++++--------
 mbtrack2/tracking/synchrotron.py | 27 ++++++++++++++++++++++-----
 2 files changed, 44 insertions(+), 13 deletions(-)

diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index d065078..81fbdb5 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -176,14 +176,20 @@ class TransverseMap(Element):
             phase_advance_y = 2*np.pi * (self.ring.tune[1] + 
                                          self.ring.chro[1]*bunch["delta"])
         else:
+            Jx = (self.ring.optics.local_gamma[0] * bunch['x']**2) + \
+                  (2*self.ring.optics.local_alpha[0] * bunch['x']*bunch['xp']) + \
+                  (self.ring.optics.local_beta[0] * bunch['xp']**2)
+            Jy = (self.ring.optics.local_gamma[1] * bunch['y']**2) + \
+                  (2*self.ring.optics.local_alpha[1] * bunch['y']*bunch['yp']) + \
+                  (self.ring.optics.local_beta[1] * bunch['yp']**2)
             phase_advance_x = 2*np.pi * (self.ring.tune[0] + 
                                          self.ring.chro[0]*bunch["delta"] + 
-                                         self.adts_poly[0](bunch['x']) + 
-                                         self.adts_poly[2](bunch['y']))
+                                         self.adts_poly[0](Jx) + 
+                                         self.adts_poly[2](Jy))
             phase_advance_y = 2*np.pi * (self.ring.tune[1] + 
                                          self.ring.chro[1]*bunch["delta"] +
-                                         self.adts_poly[1](bunch['x']) + 
-                                         self.adts_poly[3](bunch['y']))
+                                         self.adts_poly[1](Jx) + 
+                                         self.adts_poly[3](Jy))
         
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
         matrix = np.zeros((6,6,len(bunch)))
@@ -282,9 +288,11 @@ class TransverseMapSector(Element):
         self.ring = ring
         self.alpha0 = alpha0
         self.beta0 = beta0
+        self.gamma0 = (1 + self.alpha0**2)/self.beta0
         self.dispersion0 = dispersion0
         self.alpha1 = alpha1
         self.beta1 = beta1
+        self.gamma1 = (1 + self.alpha1**2)/self.beta1
         self.dispersion1 = dispersion1  
         self.tune_diff = phase_diff / (2*np.pi)
         self.chro_diff = chro_diff
@@ -315,14 +323,20 @@ class TransverseMapSector(Element):
             phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
                                          self.chro_diff[1]*bunch["delta"])
         else:
+            Jx = (self.gamma0[0] * bunch['x']**2) + \
+                  (2*self.alpha0[0] * bunch['x']*self['xp']) + \
+                  (self.beta0[0] * bunch['xp']**2)
+            Jy = (self.gamma0[1] * bunch['y']**2) + \
+                  (2*self.alpha0[1] * bunch['y']*bunch['yp']) + \
+                  (self.beta0[1] * bunch['yp']**2)
             phase_advance_x = 2*np.pi * (self.tune_diff[0] + 
                                          self.chro_diff[0]*bunch["delta"] + 
-                                         self.adts_poly[0](bunch['x']) + 
-                                         self.adts_poly[2](bunch['y']))
+                                         self.adts_poly[0](Jx) + 
+                                         self.adts_poly[2](Jy))
             phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
                                          self.chro_diff[1]*bunch["delta"] +
-                                         self.adts_poly[1](bunch['x']) + 
-                                         self.adts_poly[3](bunch['y']))
+                                         self.adts_poly[1](Jx) + 
+                                         self.adts_poly[3](Jy))
         
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
         matrix = np.zeros((6,6,len(bunch)))
diff --git a/mbtrack2/tracking/synchrotron.py b/mbtrack2/tracking/synchrotron.py
index 1746007..27ada22 100644
--- a/mbtrack2/tracking/synchrotron.py
+++ b/mbtrack2/tracking/synchrotron.py
@@ -4,6 +4,7 @@ Module where the Synchrotron class is defined.
 """
 
 import numpy as np
+import at
 from scipy.constants import c, e
         
 class Synchrotron:
@@ -47,13 +48,13 @@ class Synchrotron:
         The order of the elements strictly needs to be
         [coef_xx, coef_yx, coef_xy, coef_yy], where x and y denote the horizontal
         and the vertical plane, respectively, and coef_PQ means the polynomial's
-        coefficients of the ADTS in plane P due to the offset in plane Q.
+        coefficients of the ADTS in plane P due to the amplitude in plane Q.
 
-        For example, if the tune shift in y due to the offset x is characterized
-        by the equation dQy(x) = 3*x**2 + 2*x + 1, then coef_yx takes the form
-        np.array([3, 2, 1]).
+        For example, if the tune shift in y due to the Courant-Snyder invariant 
+        Jx is characterized by the equation dQy(x) = 3*Jx**2 + 2*Jx + 1, then 
+        coef_yx takes the form np.array([3, 2, 1]).
         
-        Use None, to exclude the ADTS calculation.
+        Use None, to exclude the ADTS from calculation.
         
     Attributes
     ----------
@@ -296,3 +297,19 @@ class Synchrotron:
         tuneS = np.sqrt( - (Vrf / self.E0) * (self.h * self.ac) / (2*np.pi) 
                         * np.cos(phase) )
         return tuneS
+    
+    def get_adts(self):
+        """
+        Compute and add Amplitude-Dependent Tune Shifts (ADTS) sextupolar 
+        componenet from AT lattice.
+        """
+        if self.optics.use_local_values:
+            raise ValueError("ADTS needs to be provided manualy as no AT" + 
+                             " lattice file is loaded.")
+            
+        det = at.physics.nonlinear.gen_detuning_elem(self.optics.lattice)
+        coef_xx = np.array([det.A1/2, 0])
+        coef_yx = np.array([det.A2/2, 0])
+        coef_xy = np.array([det.A2/2, 0])
+        coef_yy = np.array([det.A3/2, 0])
+        self.adts = [coef_xx, coef_yx, coef_xy, coef_yy]
-- 
GitLab