diff --git a/tracking/parallel.py b/tracking/parallel.py index 6de53e785970fca7d2e991a7b84bed72346be3f4..5145c50e591ce9cf07f44bce0d71aafb6a13f0de 100644 --- a/tracking/parallel.py +++ b/tracking/parallel.py @@ -8,7 +8,6 @@ Module to handle parallel computation """ import numpy as np -from mpi4py import MPI class Mpi: """ @@ -52,7 +51,8 @@ class Mpi: """ def __init__(self, filling_pattern): - + from mpi4py import MPI + self.MPI = MPI self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() @@ -169,10 +169,10 @@ class Mpi: bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n) self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64)) - self.comm.Allgather([center, MPI.DOUBLE], [self.__getattribute__(dim + "_center"), MPI.DOUBLE]) + self.comm.Allgather([center, self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE]) self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64)) - self.comm.Allgather([profile, MPI.INT64_T], [self.__getattribute__(dim + "_profile"), MPI.INT64_T]) + self.comm.Allgather([profile, self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T]) self.__setattr__(dim + "_sorted_index", sorted_index) diff --git a/tracking/particles.py b/tracking/particles.py index fc75fa34931457605ba18f8edf841fa1cf826c08..e7939e8444dfb05ced7623817993ad10db95aedc 100644 --- a/tracking/particles.py +++ b/tracking/particles.py @@ -11,7 +11,6 @@ import numpy as np import matplotlib.pyplot as plt import seaborn as sns import pandas as pd -from mbtrack2.tracking.parallel import Mpi from scipy.constants import c, m_e, m_p, e class Particle: @@ -679,6 +678,7 @@ class Beam: def mpi_init(self): """Switch on MPI parallelisation and initialise a Mpi object""" + from mbtrack2.tracking.parallel import Mpi self.mpi = Mpi(self.filling_pattern) self.mpi_switch = True