Skip to content
Snippets Groups Projects
Commit 1f8d3b69 authored by Gamelin Alexis's avatar Gamelin Alexis
Browse files

Mpi class for beam and mpi_gather function for beam

New parallel module to handle Mpi class which is used in Beam
parent 6b7a765e
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ Beam and bunch elements
import numpy as np
import pandas as pd
from Tracking.parallel import Mpi
class Bunch:
"""
......@@ -257,7 +258,7 @@ class Beam:
def __init__(self, ring, bunch_list=None):
self.ring = ring
self.mpi_switch = False
if bunch_list is None:
self.init_beam(np.zeros((self.ring.h,1),dtype=bool))
else:
......@@ -348,7 +349,7 @@ class Beam:
"""Return an array with the filling pattern of the beam as bool"""
filling_pattern = []
for bunch in self:
if filling_pattern != 0:
if bunch.current != 0:
filling_pattern.append(True)
else:
filling_pattern.append(False)
......@@ -413,4 +414,25 @@ class Beam:
for index, bunch in enumerate(self):
bunch_emit[:,index] = np.squeeze(bunch.emit)
return bunch_emit
def mpi_init(self):
""" Initialise mpi """
self.mpi = Mpi(self.filling_pattern)
self.mpi_switch = True
def mpi_gather(self):
"""Gather beam, all bunches of the different processors are sent to
all processors. Rather slow"""
if(self.mpi_switch == False):
print("Error, mpi is not initialised.")
bunch = self[self.mpi.bunch_num]
bunches = self.mpi.comm.allgather(bunch)
for rank in range(self.mpi.size):
self[self.mpi.rank_to_bunch(rank)] = bunches[rank]
\ No newline at end of file
......@@ -32,9 +32,8 @@ class Element(metaclass=ABCMeta):
def wrapper(*args, **kwargs):
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
if (beam.mpi_switch == True):
function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
else:
for bunch in beam:
function(self, bunch, *args[2:], **kwargs)
......@@ -48,9 +47,8 @@ class Element(metaclass=ABCMeta):
def wrapper(*args, **kwargs):
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
if (beam.mpi_switch == True):
function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
else:
for bunch in beam.not_empty:
function(self, bunch, *args[2:], **kwargs)
......
# -*- coding: utf-8 -*-
"""
Module to handle parallel computation
@author: Alexis Gamelin
@date: 06/03/2020
"""
import numpy as np
class Mpi:
""" Class which handle mpi
"""
def __init__(self, filling_pattern):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("Please install mpi4py module.")
else:
print("MPI error.")
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.write_table(filling_pattern)
def write_table(self, filling_pattern):
"""
Write a table with the rank and the corresponding bunch number for each
bunch of the filling pattern
"""
if(filling_pattern.sum() != self.size):
raise ValueError("The number of processors must be equal to the"
"number of (non-empty) bunches.")
table = np.zeros((self.size, 2), dtype = int)
table[:,0] = np.arange(0, self.size)
table[:,1] = np.where(filling_pattern)[0]
self.table = table
def rank_to_bunch(self, rank):
"""Return the bunch number corresponding to rank"""
return self.table[rank,1]
def bunch_to_rank(self, bunch_num):
"""Return the rank corresponding to the bunch number bunch_num"""
try:
rank = np.where(self.table[:,1] == bunch_num)[0][0]
except IndexError:
print("The bunch " + str(bunch_num) + " is not tracked on any processor.")
rank = None
return rank
@property
def bunch_num(self):
"""Return the bunch number corresponding to the current processor"""
return self.rank_to_bunch(self.rank)
\ No newline at end of file
......@@ -68,9 +68,7 @@ class Synchrotron:
self.emit = kwargs.get('emit') # X/Y emittances in [m.rad]
self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
self.mpi = False
@property
def h(self):
"""Harmonic number"""
......@@ -183,14 +181,4 @@ class Synchrotron:
def eta(self):
"""Momentum compaction"""
return self.ac - 1/(self.gamma**2)
def mpi_init(self):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("mpi4py not found")
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
return rank
\ No newline at end of file
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment