From e94d5e5f31bb646d7f00c9f92caa6df1fbfb0c19 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Thu, 29 Apr 2021 19:22:30 +0200
Subject: [PATCH] Add a mpi mode to Beam.init_beam

This avoid to store a huge amount of data in all cores before calling mpi_init to share it to all the cores.
---
 tracking/particles.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/tracking/particles.py b/tracking/particles.py
index 4e51537..1165ee0 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -524,7 +524,7 @@ class Beam:
         self._distance_between_bunches =  distance
         
     def init_beam(self, filling_pattern, current_per_bunch=1e-3, 
-                  mp_per_bunch=1e3, track_alive=True):
+                  mp_per_bunch=1e3, track_alive=True, mpi=False):
         """
         Initialize beam with a given filling pattern and marco-particle number 
         per bunch. Then initialize the different bunches with a 6D gaussian
@@ -548,12 +548,19 @@ class Beam:
             If False, the code no longer take into account alive/dead particles.
             Should be set to True if element such as apertures are used.
             Can be set to False to gain a speed increase.
+        mpi : bool, optional
+            If True, only a single bunch is fully initialized on each core, the
+            other bunches are initialized with a single marco-particle.
         """
         
         if (len(filling_pattern) != self.ring.h):
             raise ValueError(("The length of filling pattern is {} ".format(len(filling_pattern)) + 
                               "but should be {}".format(self.ring.h)))
         
+        if mpi is True:
+            mp_per_bunch_mpi = mp_per_bunch
+            mp_per_bunch = 1
+        
         filling_pattern = np.array(filling_pattern)
         bunch_list = []
         if filling_pattern.dtype == np.dtype("bool"):
@@ -577,8 +584,15 @@ class Beam:
         self.update_filling_pattern()
         self.update_distance_between_bunches()
         
-        for bunch in self.not_empty:
+        if mpi is True:
+            self.mpi_init()
+            current = self[self.mpi.rank_to_bunch(self.mpi.rank)].current
+            bunch =  Bunch(self.ring, mp_per_bunch_mpi, current, track_alive)
             bunch.init_gaussian()
+            self[self.mpi.rank_to_bunch(self.mpi.rank)] = bunch
+        else:
+            for bunch in self.not_empty:
+                bunch.init_gaussian()
     
     def update_filling_pattern(self):
         """Update the beam filling pattern."""
-- 
GitLab