Skip to content
Snippets Groups Projects
Commit 6b7a765e authored by Gamelin Alexis's avatar Gamelin Alexis
Browse files

MPI tracking test

Test OK, to improve
parent cca6cfb1
No related branches found
No related tags found
No related merge requests found
......@@ -30,8 +30,14 @@ class Element(metaclass=ABCMeta):
"""
@wraps(function)
def wrapper(*args, **kwargs):
for bunch in args[0]:
function(bunch, *args[1:], **kwargs)
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
else:
for bunch in beam:
function(self, bunch, *args[2:], **kwargs)
return wrapper
@staticmethod
......@@ -40,8 +46,14 @@ class Element(metaclass=ABCMeta):
"""
@wraps(function)
def wrapper(*args, **kwargs):
for bunch in args[0].not_empty:
function(bunch, *args[1:], **kwargs)
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
else:
for bunch in beam.not_empty:
function(self, bunch, *args[2:], **kwargs)
return wrapper
......@@ -53,30 +65,24 @@ class Long_one_turn(Element):
def __init__(self, ring):
self.ring = ring
def track_bunch(self, bunch):
@Element.not_empty
def track(self, bunch):
"""track"""
bunch["delta"] -= self.ring.U0/self.ring.E0
bunch["tau"] -= self.ring.ac*self.ring.T0*bunch["delta"]
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class SynchrotronRadiation(Element):
"""SyncRad"""
def __init__(self, ring):
self.ring = ring
def track_bunch(self, bunch):
@Element.not_empty
def track(self, bunch):
"""track"""
rand = np.random.normal(size=len(bunch))
bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] +
2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand)
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class RF_cavity(Element):
""" Perfect RF cavity class for main and harmonic RF cavities."""
......@@ -86,13 +92,11 @@ class RF_cavity(Element):
self.Vc = Vc # Amplitude of Cavity voltage [V]
self.theta = theta # phase of Cavity voltage
def track_bunch(self,bunch):
@Element.not_empty
def track(self,bunch):
"""Track """
bunch["delta"] += self.Vc/self.ring.E0*np.cos(self.m*self.ring.omega1*bunch["tau"] + self.theta)
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
class Trans_one_turn(Element):
"""
......@@ -107,8 +111,9 @@ class Trans_one_turn(Element):
self.disp = self.ring.mean_optics.disp
self.dispp = self.ring.mean_optics.dispp
self.phase_advance = self.ring.tune[0:2]*2*np.pi
def track_bunch(self, bunch):
@Element.not_empty
def track(self, bunch):
"""track"""
phase_advance_x = self.phase_advance[0]*(1+self.ring.chro[0]*bunch["delta"])
......@@ -139,7 +144,4 @@ class Trans_one_turn(Element):
bunch["xp"] = xp
bunch["y"] = y
bunch["yp"] = yp
def track(self, beam):
"""track full beam"""
super().not_empty(self.track_bunch)(beam)
......@@ -69,6 +69,8 @@ class Synchrotron:
self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
self.mpi = False
@property
def h(self):
"""Harmonic number"""
......@@ -180,4 +182,15 @@ class Synchrotron:
@property
def eta(self):
"""Momentum compaction"""
return self.ac - 1/(self.gamma**2)
\ No newline at end of file
return self.ac - 1/(self.gamma**2)
def mpi_init(self):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("mpi4py not found")
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
return rank
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment