Skip to content
Snippets Groups Projects
Commit 6b7a765e authored by Gamelin Alexis's avatar Gamelin Alexis
Browse files

MPI tracking test

Test OK, to improve
parent cca6cfb1
No related branches found
No related tags found
No related merge requests found
...@@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta): ...@@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta):
""" """
@wraps(function) @wraps(function)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for bunch in args[0]: self = args[0]
function(bunch, *args[1:], **kwargs) beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
else:
for bunch in beam:
function(self, bunch, *args[2:], **kwargs)
return wrapper return wrapper
@staticmethod @staticmethod
...@@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta): ...@@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta):
""" """
@wraps(function) @wraps(function)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
for bunch in args[0].not_empty: self = args[0]
function(bunch, *args[1:], **kwargs) beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
else:
for bunch in beam.not_empty:
function(self, bunch, *args[2:], **kwargs)
return wrapper return wrapper
...@@ -53,30 +65,24 @@ class Long_one_turn(Element): ...@@ -53,30 +65,24 @@ class Long_one_turn(Element):
def __init__(self, ring): def __init__(self, ring):
self.ring = ring self.ring = ring
def track_bunch(self, bunch): @Element.not_empty
def track(self, bunch):
"""track""" """track"""
bunch["delta"] -= self.ring.U0/self.ring.E0 bunch["delta"] -= self.ring.U0/self.ring.E0
bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"] bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"]
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class SynchrotronRadiation(Element): class SynchrotronRadiation(Element):
"""SyncRad""" """SyncRad"""
def __init__(self, ring): def __init__(self, ring):
self.ring = ring self.ring = ring
def track_bunch(self, bunch): @Element.not_empty
def track(self, bunch):
"""track""" """track"""
rand = np.random.normal(size=len(bunch)) rand = np.random.normal(size=len(bunch))
bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] + bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] +
2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand) 2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand)
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class RF_cavity(Element): class RF_cavity(Element):
""" Perfect RF cavity class for main and harmonic RF cavities.""" """ Perfect RF cavity class for main and harmonic RF cavities."""
...@@ -86,13 +92,11 @@ class RF_cavity(Element): ...@@ -86,13 +92,11 @@ class RF_cavity(Element):
self.Vc = Vc # Amplitude of Cavity voltage [V] self.Vc = Vc # Amplitude of Cavity voltage [V]
self.theta = theta # phase of Cavity voltage self.theta = theta # phase of Cavity voltage
def track_bunch(self,bunch): @Element.not_empty
def track(self,bunch):
"""Track """ """Track """
bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta) bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta)
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class Trans_one_turn(Element): class Trans_one_turn(Element):
""" """
...@@ -107,8 +111,9 @@ class Trans_one_turn(Element): ...@@ -107,8 +111,9 @@ class Trans_one_turn(Element):
self.disp = self.ring.mean_optics.disp self.disp = self.ring.mean_optics.disp
self.dispp = self.ring.mean_optics.dispp self.dispp = self.ring.mean_optics.dispp
self.phase_advance = self.ring.tune[0:2]*2*np.pi self.phase_advance = self.ring.tune[0:2]*2*np.pi
def track_bunch(self, bunch): @Element.not_empty
def track(self, bunch):
"""track""" """track"""
phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"]) phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"])
...@@ -139,7 +144,4 @@ class Trans_one_turn(Element): ...@@ -139,7 +144,4 @@ class Trans_one_turn(Element):
bunch["xp"] = xp bunch["xp"] = xp
bunch["y"] = y bunch["y"] = y
bunch["yp"] = yp bunch["yp"] = yp
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
...@@ -69,6 +69,8 @@ class Synchrotron: ...@@ -69,6 +69,8 @@ class Synchrotron:
self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
self.mpi = False
@property @property
def h(self): def h(self):
"""Harmonic number""" """Harmonic number"""
...@@ -180,4 +182,15 @@ class Synchrotron: ...@@ -180,4 +182,15 @@ class Synchrotron:
@property @property
def eta(self): def eta(self):
"""Momentum compaction""" """Momentum compaction"""
return self.ac - 1/(self.gamma**2) return self.ac - 1/(self.gamma**2)
\ No newline at end of file
def mpi_init(self):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("mpi4py not found")
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
return rank
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment