diff --git a/Tracking/one_turn_matrix.py b/Tracking/one_turn_matrix.py index eea7ca39c1767fa1a0fc1dd6d9c54f0af72b1dbd..bdf74c7567ec3ef51d417f4e521c327e3d272450 100644 --- a/Tracking/one_turn_matrix.py +++ b/Tracking/one_turn_matrix.py @@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta): """ @wraps(function) def wrapper(*args, **kwargs): - for bunch in args[0]: - function(bunch, *args[1:], **kwargs) + self = args[0] + beam = args[1] + if (self.ring.mpi == True): + rank = self.ring.mpi_init() + function(self, beam[rank], *args[2:], **kwargs) + else: + for bunch in beam: + function(self, bunch, *args[2:], **kwargs) return wrapper @staticmethod @@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta): """ @wraps(function) def wrapper(*args, **kwargs): - for bunch in args[0].not_empty: - function(bunch, *args[1:], **kwargs) + self = args[0] + beam = args[1] + if (self.ring.mpi == True): + rank = self.ring.mpi_init() + function(self, beam[rank], *args[2:], **kwargs) + else: + for bunch in beam.not_empty: + function(self, bunch, *args[2:], **kwargs) return wrapper @@ -53,30 +65,24 @@ class Long_one_turn(Element): def __init__(self, ring): self.ring = ring - def track_bunch(self, bunch): + @Element.not_empty + def track(self, bunch): """track""" bunch["delta"] -= self.ring.U0/self.ring.E0 bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"] - def track(self, beam): - """track full beam""" - super().not_empty(self.track_bunch)(beam) - class SynchrotronRadiation(Element): """SyncRad""" def __init__(self, ring): self.ring = ring - def track_bunch(self, bunch): + @Element.not_empty + def track(self, bunch): """track""" rand = np.random.normal(size=len(bunch)) bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] + 2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand) - - def track(self, beam): - """track full beam""" - super().not_empty(self.track_bunch)(beam) class RF_cavity(Element): """ Perfect RF cavity class for main and harmonic RF cavities.""" @@ -86,13 +92,11 @@ class RF_cavity(Element): self.Vc = Vc # Amplitude of Cavity voltage [V] self.theta = theta # phase of Cavity voltage - def track_bunch(self,bunch): + @Element.not_empty + def track(self,bunch): """Track """ bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta) - def track(self, beam): - """track full beam""" - super().not_empty(self.track_bunch)(beam) class Trans_one_turn(Element): """ @@ -107,8 +111,9 @@ class Trans_one_turn(Element): self.disp = self.ring.mean_optics.disp self.dispp = self.ring.mean_optics.dispp self.phase_advance = self.ring.tune[0:2]*2*np.pi - - def track_bunch(self, bunch): + + @Element.not_empty + def track(self, bunch): """track""" phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"]) @@ -139,7 +144,4 @@ class Trans_one_turn(Element): bunch["xp"] = xp bunch["y"] = y bunch["yp"] = yp - - def track(self, beam): - """track full beam""" - super().not_empty(self.track_bunch)(beam) + diff --git a/Tracking/synchrotron.py b/Tracking/synchrotron.py index 60ab8438099a5e72b4a37aee26cae372d7c39984..fc99fcb7e49f9babd94011f3907308e295a3db4a 100644 --- a/Tracking/synchrotron.py +++ b/Tracking/synchrotron.py @@ -69,6 +69,8 @@ class Synchrotron: self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values + self.mpi = False + @property def h(self): """Harmonic number""" @@ -180,4 +182,15 @@ class Synchrotron: @property def eta(self): """Momentum compaction""" - return self.ac - 1/(self.gamma**2) \ No newline at end of file + return self.ac - 1/(self.gamma**2) + + def mpi_init(self): + try: + from mpi4py import MPI + except(ModuleNotFoundError): + print("mpi4py not found") + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + return rank + + \ No newline at end of file