Skip to content
Snippets Groups Projects
Commit a80ea9ab authored by Alexis GAMELIN's avatar Alexis GAMELIN
Browse files

[Fix] store MPI object in the Mpi class

parent e0e90ccb
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,6 @@ Module to handle parallel computation ...@@ -8,7 +8,6 @@ Module to handle parallel computation
""" """
import numpy as np import numpy as np
from mpi4py import MPI
class Mpi: class Mpi:
""" """
...@@ -52,7 +51,8 @@ class Mpi: ...@@ -52,7 +51,8 @@ class Mpi:
""" """
def __init__(self, filling_pattern): def __init__(self, filling_pattern):
from mpi4py import MPI
self.MPI = MPI
self.comm = MPI.COMM_WORLD self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank() self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size() self.size = self.comm.Get_size()
...@@ -169,10 +169,10 @@ class Mpi: ...@@ -169,10 +169,10 @@ class Mpi:
bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n) bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n)
self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64)) self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64))
self.comm.Allgather([center, MPI.DOUBLE], [self.__getattribute__(dim + "_center"), MPI.DOUBLE]) self.comm.Allgather([center, self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE])
self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64)) self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64))
self.comm.Allgather([profile, MPI.INT64_T], [self.__getattribute__(dim + "_profile"), MPI.INT64_T]) self.comm.Allgather([profile, self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T])
self.__setattr__(dim + "_sorted_index", sorted_index) self.__setattr__(dim + "_sorted_index", sorted_index)
...@@ -11,7 +11,6 @@ import numpy as np ...@@ -11,7 +11,6 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
import pandas as pd import pandas as pd
from mbtrack2.tracking.parallel import Mpi
from scipy.constants import c, m_e, m_p, e from scipy.constants import c, m_e, m_p, e
class Particle: class Particle:
...@@ -679,6 +678,7 @@ class Beam: ...@@ -679,6 +678,7 @@ class Beam:
def mpi_init(self): def mpi_init(self):
"""Switch on MPI parallelisation and initialise a Mpi object""" """Switch on MPI parallelisation and initialise a Mpi object"""
from mbtrack2.tracking.parallel import Mpi
self.mpi = Mpi(self.filling_pattern) self.mpi = Mpi(self.filling_pattern)
self.mpi_switch = True self.mpi_switch = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment