From 6b7a765e5f1a2a0102e10798eb84cc692ae479ab Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Fri, 6 Mar 2020 15:13:00 +0100
Subject: [PATCH] MPI tracking test

Test OK, to improve
---
 Tracking/one_turn_matrix.py | 50 +++++++++++++++++++------------------
 Tracking/synchrotron.py     | 15 ++++++++++-
 2 files changed, 40 insertions(+), 25 deletions(-)

diff --git a/Tracking/one_turn_matrix.py b/Tracking/one_turn_matrix.py
index eea7ca3..bdf74c7 100644
--- a/Tracking/one_turn_matrix.py
+++ b/Tracking/one_turn_matrix.py
@@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta):
         """
         @wraps(function)
         def wrapper(*args, **kwargs):
-            for bunch in args[0]:
-                function(bunch, *args[1:], **kwargs)
+            self = args[0]
+            beam = args[1]
+            if (self.ring.mpi == True):
+                rank = self.ring.mpi_init()
+                function(self, beam[rank], *args[2:], **kwargs)
+            else:
+                for bunch in beam:
+                    function(self, bunch, *args[2:], **kwargs)
         return wrapper
     
     @staticmethod
@@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta):
         """
         @wraps(function)
         def wrapper(*args, **kwargs):
-            for bunch in args[0].not_empty:
-                function(bunch, *args[1:], **kwargs)
+            self = args[0]
+            beam = args[1]
+            if (self.ring.mpi == True):
+                rank = self.ring.mpi_init()
+                function(self, beam[rank], *args[2:], **kwargs)
+            else:
+                for bunch in beam.not_empty:
+                    function(self, bunch, *args[2:], **kwargs)
         return wrapper
 
         
@@ -53,30 +65,24 @@ class Long_one_turn(Element):
     def __init__(self, ring):
         self.ring = ring
         
-    def track_bunch(self, bunch):
+    @Element.not_empty
+    def track(self, bunch):
         """track"""
         bunch["delta"] -= self.ring.U0/self.ring.E0
         bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"]
         
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
-        
 class SynchrotronRadiation(Element):
     """SyncRad"""
     
     def __init__(self, ring):
         self.ring = ring
         
-    def track_bunch(self, bunch):
+    @Element.not_empty        
+    def track(self, bunch):
         """track"""
         rand = np.random.normal(size=len(bunch))
         bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] +
              2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand)
-        
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
 
 class RF_cavity(Element):
     """ Perfect RF cavity class for main and harmonic RF cavities."""
@@ -86,13 +92,11 @@ class RF_cavity(Element):
         self.Vc = Vc # Amplitude of Cavity voltage [V]
         self.theta = theta # phase of Cavity voltage
         
-    def track_bunch(self,bunch):
+    @Element.not_empty    
+    def track(self,bunch):
         """Track """
         bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta)
         
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
         
 class Trans_one_turn(Element):
     """
@@ -107,8 +111,9 @@ class Trans_one_turn(Element):
         self.disp = self.ring.mean_optics.disp
         self.dispp = self.ring.mean_optics.dispp
         self.phase_advance = self.ring.tune[0:2]*2*np.pi
-        
-    def track_bunch(self, bunch):
+    
+    @Element.not_empty    
+    def track(self, bunch):
         """track"""
 
         phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"])
@@ -139,7 +144,4 @@ class Trans_one_turn(Element):
         bunch["xp"] = xp
         bunch["y"] = y
         bunch["yp"] = yp
-        
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
+
diff --git a/Tracking/synchrotron.py b/Tracking/synchrotron.py
index 60ab843..fc99fcb 100644
--- a/Tracking/synchrotron.py
+++ b/Tracking/synchrotron.py
@@ -69,6 +69,8 @@ class Synchrotron:
         self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
         self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
         
+        self.mpi = False
+        
     @property
     def h(self):
         """Harmonic number"""
@@ -180,4 +182,15 @@ class Synchrotron:
     @property
     def eta(self):
         """Momentum compaction"""
-        return self.ac - 1/(self.gamma**2)
\ No newline at end of file
+        return self.ac - 1/(self.gamma**2)
+    
+    def mpi_init(self):
+        try:
+            from mpi4py import MPI
+        except(ModuleNotFoundError):
+            print("mpi4py not found")
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        return rank
+
+        
\ No newline at end of file
-- 
GitLab