From a80ea9ab2425f67073c96657c3bcc1ee86d0823e Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Tue, 20 Jul 2021 16:34:28 +0200
Subject: [PATCH] [Fix] store MPI object in the Mpi class

---
 tracking/parallel.py  | 8 ++++----
 tracking/particles.py | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tracking/parallel.py b/tracking/parallel.py
index 6de53e7..5145c50 100644
--- a/tracking/parallel.py
+++ b/tracking/parallel.py
@@ -8,7 +8,6 @@ Module to handle parallel computation
 """
 
 import numpy as np
-from mpi4py import MPI
 
 class Mpi:
     """
@@ -52,7 +51,8 @@ class Mpi:
     """
     
     def __init__(self, filling_pattern):
-
+        from mpi4py import MPI
+        self.MPI = MPI
         self.comm = MPI.COMM_WORLD
         self.rank = self.comm.Get_rank()
         self.size = self.comm.Get_size()
@@ -169,10 +169,10 @@ class Mpi:
             bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n)
             
             self.__setattr__(dim + "_center", np.empty((len(beam), len(center)), dtype=np.float64))
-            self.comm.Allgather([center,  MPI.DOUBLE], [self.__getattribute__(dim + "_center"), MPI.DOUBLE])
+            self.comm.Allgather([center,  self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE])
             
             self.__setattr__(dim + "_profile", np.empty((len(beam), len(profile)), dtype=np.int64))
-            self.comm.Allgather([profile,  MPI.INT64_T], [self.__getattribute__(dim + "_profile"), MPI.INT64_T])
+            self.comm.Allgather([profile,  self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T])
             
             self.__setattr__(dim + "_sorted_index", sorted_index)
     
diff --git a/tracking/particles.py b/tracking/particles.py
index fc75fa3..e7939e8 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -11,7 +11,6 @@ import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
 import pandas as pd
-from mbtrack2.tracking.parallel import Mpi
 from scipy.constants import c, m_e, m_p, e
 
 class Particle:
@@ -679,6 +678,7 @@ class Beam:
     
     def mpi_init(self):
         """Switch on MPI parallelisation and initialise a Mpi object"""
+        from mbtrack2.tracking.parallel import Mpi
         self.mpi = Mpi(self.filling_pattern)
         self.mpi_switch = True
         
-- 
GitLab