From a80ea9ab2425f67073c96657c3bcc1ee86d0823e Mon Sep 17 00:00:00 2001 From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr> Date: Tue, 20 Jul 2021 16:34:28 +0200 Subject: [PATCH] [Fix] store MPI object in the Mpi class --- tracking/parallel.py | 8 ++++---- tracking/particles.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tracking/parallel.py b/tracking/parallel.py index 6de53e7..5145c50 100644 --- a/tracking/parallel.py +++ b/tracking/parallel.py @@ -8,7 +8,6 @@ Module to handle parallel computation """ import numpy as np -from mpi4py import MPI class Mpi: """ @@ -52,7 +51,8 @@ class Mpi: """ def __init__(self, filling_pattern): - + from mpi4py import MPI + self.MPI = MPI self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() @@ -169,10 +169,10 @@ class Mpi: bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n) self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64)) - self.comm.Allgather([center, MPI.DOUBLE], [self.__getattribute__(dim + "_center"), MPI.DOUBLE]) + self.comm.Allgather([center, self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE]) self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64)) - self.comm.Allgather([profile, MPI.INT64_T], [self.__getattribute__(dim + "_profile"), MPI.INT64_T]) + self.comm.Allgather([profile, self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T]) self.__setattr__(dim + "_sorted_index", sorted_index) diff --git a/tracking/particles.py b/tracking/particles.py index fc75fa3..e7939e8 100644 --- a/tracking/particles.py +++ b/tracking/particles.py @@ -11,7 +11,6 @@ import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd -from mbtrack2.tracking.parallel import Mpi from scipy.constants import c, m_e, m_p, e class Particle: @@ -679,6 +678,7 @@ class Beam: def mpi_init(self): """Switch on MPI parallelisation and initialise a Mpi object""" + from mbtrack2.tracking.parallel import Mpi self.mpi = Mpi(self.filling_pattern) self.mpi_switch = True -- GitLab