diff --git a/Tracking/one_turn_matrix.py b/Tracking/one_turn_matrix.py
index eea7ca39c1767fa1a0fc1dd6d9c54f0af72b1dbd..bdf74c7567ec3ef51d417f4e521c327e3d272450 100644
--- a/Tracking/one_turn_matrix.py
+++ b/Tracking/one_turn_matrix.py
@@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta):
         """
         @wraps(function)
         def wrapper(*args, **kwargs):
-            for bunch in args[0]:
-                function(bunch, *args[1:], **kwargs)
+            self = args[0]
+            beam = args[1]
+            if (self.ring.mpi == True):
+                rank = self.ring.mpi_init()
+                function(self, beam[rank], *args[2:], **kwargs)
+            else:
+                for bunch in beam:
+                    function(self, bunch, *args[2:], **kwargs)
         return wrapper
     
     @staticmethod
@@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta):
         """
         @wraps(function)
         def wrapper(*args, **kwargs):
-            for bunch in args[0].not_empty:
-                function(bunch, *args[1:], **kwargs)
+            self = args[0]
+            beam = args[1]
+            if (self.ring.mpi == True):
+                rank = self.ring.mpi_init()
+                function(self, beam[rank], *args[2:], **kwargs)
+            else:
+                for bunch in beam.not_empty:
+                    function(self, bunch, *args[2:], **kwargs)
         return wrapper
 
         
@@ -53,30 +65,24 @@ class Long_one_turn(Element):
     def __init__(self, ring):
         self.ring = ring
         
-    def track_bunch(self, bunch):
+    @Element.not_empty
+    def track(self, bunch):
         """track"""
         bunch["delta"] -= self.ring.U0/self.ring.E0
         bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"]
         
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
-        
 class SynchrotronRadiation(Element):
     """SyncRad"""
     
     def __init__(self, ring):
         self.ring = ring
         
-    def track_bunch(self, bunch):
+    @Element.not_empty        
+    def track(self, bunch):
         """track"""
         rand = np.random.normal(size=len(bunch))
         bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] +
              2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand)
-        
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
 
 class RF_cavity(Element):
     """ Perfect RF cavity class for main and harmonic RF cavities."""
@@ -86,13 +92,11 @@ class RF_cavity(Element):
         self.Vc = Vc # Amplitude of Cavity voltage [V]
         self.theta = theta # phase of Cavity voltage
         
-    def track_bunch(self,bunch):
+    @Element.not_empty    
+    def track(self,bunch):
         """Track """
         bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta)
         
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
         
 class Trans_one_turn(Element):
     """
@@ -107,8 +111,9 @@ class Trans_one_turn(Element):
         self.disp = self.ring.mean_optics.disp
         self.dispp = self.ring.mean_optics.dispp
         self.phase_advance = self.ring.tune[0:2]*2*np.pi
-        
-    def track_bunch(self, bunch):
+    
+    @Element.not_empty    
+    def track(self, bunch):
         """track"""
 
         phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"])
@@ -139,7 +144,4 @@ class Trans_one_turn(Element):
         bunch["xp"] = xp
         bunch["y"] = y
         bunch["yp"] = yp
-        
-    def track(self, beam):
-        """track full beam"""
-        super().not_empty(self.track_bunch)(beam)
+
diff --git a/Tracking/synchrotron.py b/Tracking/synchrotron.py
index 60ab8438099a5e72b4a37aee26cae372d7c39984..fc99fcb7e49f9babd94011f3907308e295a3db4a 100644
--- a/Tracking/synchrotron.py
+++ b/Tracking/synchrotron.py
@@ -69,6 +69,8 @@ class Synchrotron:
         self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
         self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
         
+        self.mpi = False
+        
     @property
     def h(self):
         """Harmonic number"""
@@ -180,4 +182,15 @@ class Synchrotron:
     @property
     def eta(self):
         """Momentum compaction"""
-        return self.ac - 1/(self.gamma**2)
\ No newline at end of file
+        return self.ac - 1/(self.gamma**2)
+    
+    def mpi_init(self):
+        try:
+            from mpi4py import MPI
+        except(ModuleNotFoundError):
+            print("mpi4py not found")
+        comm = MPI.COMM_WORLD
+        rank = comm.Get_rank()
+        return rank
+
+        
\ No newline at end of file