Skip to content
Snippets Groups Projects
Commit a80ea9ab authored by Alexis GAMELIN's avatar Alexis GAMELIN
Browse files

[Fix] store MPI object in the Mpi class

parent e0e90ccb
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,6 @@ Module to handle parallel computation
"""
import numpy as np
from mpi4py import MPI
class Mpi:
"""
......@@ -52,7 +51,8 @@ class Mpi:
"""
def __init__(self, filling_pattern):
from mpi4py import MPI
self.MPI = MPI
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
......@@ -169,10 +169,10 @@ class Mpi:
bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n)
self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64))
self.comm.Allgather([center, MPI.DOUBLE], [self.__getattribute__(dim + "_center"), MPI.DOUBLE])
self.comm.Allgather([center, self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE])
self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64))
self.comm.Allgather([profile, MPI.INT64_T], [self.__getattribute__(dim + "_profile"), MPI.INT64_T])
self.comm.Allgather([profile, self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T])
self.__setattr__(dim + "_sorted_index", sorted_index)
......@@ -11,7 +11,6 @@ import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mbtrack2.tracking.parallel import Mpi
from scipy.constants import c, m_e, m_p, e
class Particle:
......@@ -679,6 +678,7 @@ class Beam:
def mpi_init(self):
"""Switch on MPI parallelisation and initialise a Mpi object"""
from mbtrack2.tracking.parallel import Mpi
self.mpi = Mpi(self.filling_pattern)
self.mpi_switch = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment