Skip to content
Snippets Groups Projects
Commit 1f8d3b69 authored by Gamelin Alexis's avatar Gamelin Alexis
Browse files

Mpi class for beam and mpi_gather function for beam

New parallel module to handle Mpi class which is used in Beam
parent 6b7a765e
Branches
Tags
No related merge requests found
......@@ -8,6 +8,7 @@ Beam and bunch elements
import numpy as np
import pandas as pd
from Tracking.parallel import Mpi
class Bunch:
"""
......@@ -257,7 +258,7 @@ class Beam:
def __init__(self, ring, bunch_list=None):
self.ring = ring
self.mpi_switch = False
if bunch_list is None:
self.init_beam(np.zeros((self.ring.h,1),dtype=bool))
else:
......@@ -348,7 +349,7 @@ class Beam:
"""Return an array with the filling pattern of the beam as bool"""
filling_pattern = []
for bunch in self:
if filling_pattern != 0:
if bunch.current != 0:
filling_pattern.append(True)
else:
filling_pattern.append(False)
......@@ -414,3 +415,24 @@ class Beam:
bunch_emit[:,index] = np.squeeze(bunch.emit)
return bunch_emit
def mpi_init(self):
""" Initialise mpi """
self.mpi = Mpi(self.filling_pattern)
self.mpi_switch = True
def mpi_gather(self):
"""Gather beam, all bunches of the different processors are sent to
all processors. Rather slow"""
if(self.mpi_switch == False):
print("Error, mpi is not initialised.")
bunch = self[self.mpi.bunch_num]
bunches = self.mpi.comm.allgather(bunch)
for rank in range(self.mpi.size):
self[self.mpi.rank_to_bunch(rank)] = bunches[rank]
\ No newline at end of file
......@@ -32,9 +32,8 @@ class Element(metaclass=ABCMeta):
def wrapper(*args, **kwargs):
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
if (beam.mpi_switch == True):
function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
else:
for bunch in beam:
function(self, bunch, *args[2:], **kwargs)
......@@ -48,9 +47,8 @@ class Element(metaclass=ABCMeta):
def wrapper(*args, **kwargs):
self = args[0]
beam = args[1]
if (self.ring.mpi == True):
rank = self.ring.mpi_init()
function(self, beam[rank], *args[2:], **kwargs)
if (beam.mpi_switch == True):
function(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
else:
for bunch in beam.not_empty:
function(self, bunch, *args[2:], **kwargs)
......
# -*- coding: utf-8 -*-
"""
Module to handle parallel computation
@author: Alexis Gamelin
@date: 06/03/2020
"""
import numpy as np
class Mpi:
""" Class which handle mpi
"""
def __init__(self, filling_pattern):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("Please install mpi4py module.")
else:
print("MPI error.")
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.write_table(filling_pattern)
def write_table(self, filling_pattern):
"""
Write a table with the rank and the corresponding bunch number for each
bunch of the filling pattern
"""
if(filling_pattern.sum() != self.size):
raise ValueError("The number of processors must be equal to the"
"number of (non-empty) bunches.")
table = np.zeros((self.size, 2), dtype = int)
table[:,0] = np.arange(0, self.size)
table[:,1] = np.where(filling_pattern)[0]
self.table = table
def rank_to_bunch(self, rank):
"""Return the bunch number corresponding to rank"""
return self.table[rank,1]
def bunch_to_rank(self, bunch_num):
"""Return the rank corresponding to the bunch number bunch_num"""
try:
rank = np.where(self.table[:,1] == bunch_num)[0][0]
except IndexError:
print("The bunch " + str(bunch_num) + " is not tracked on any processor.")
rank = None
return rank
@property
def bunch_num(self):
"""Return the bunch number corresponding to the current processor"""
return self.rank_to_bunch(self.rank)
\ No newline at end of file
......@@ -69,8 +69,6 @@ class Synchrotron:
self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
self.mean_optics = kwargs.get('mean_optics') # Optics object with mean values
self.mpi = False
@property
def h(self):
"""Harmonic number"""
......@@ -184,13 +182,3 @@ class Synchrotron:
"""Momentum compaction"""
return self.ac - 1/(self.gamma**2)
\ No newline at end of file
def mpi_init(self):
try:
from mpi4py import MPI
except(ModuleNotFoundError):
print("mpi4py not found")
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
return rank
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment