From 38aa2719113d932ccd42791255e9e09e6a5cda02 Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Thu, 12 Dec 2024 16:06:48 +0100
Subject: [PATCH] Adds Beam.init_bunch_list_mpi

Beam.update_filling_pattern and Beam.update_distance_between_bunches are now called directly in the __init__ when a bunch_list is given.
Beam.init_bunch_list_mpi is a convenience method to initialize a beam using MPI parallelisation with a Bunch per core.
Add tests for Beam.init_bunch_list_mpi
---
 mbtrack2/tracking/parallel.py        |  2 +-
 mbtrack2/tracking/particles.py       | 38 ++++++++++++++++++++++++++++
 tests/unit/tracking/test_particle.py | 23 ++++++++++++++++-
 3 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/mbtrack2/tracking/parallel.py b/mbtrack2/tracking/parallel.py
index a441f2d..c921f61 100644
--- a/mbtrack2/tracking/parallel.py
+++ b/mbtrack2/tracking/parallel.py
@@ -81,7 +81,7 @@ class Mpi:
             Filling pattern of the beam, like Beam.filling_pattern
         """
         if (filling_pattern.sum() != self.size):
-            raise ValueError("The number of processors must be equal to the"
+            raise ValueError("The number of processors must be equal to the "
                              "number of (non-empty) bunches.")
         table = np.zeros((self.size, 2), dtype=int)
         table[:, 0] = np.arange(0, self.size)
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index 0deea66..cf872de 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -687,6 +687,8 @@ class Beam:
         Initialize beam with a given filling pattern and marco-particle number
         per bunch. Then initialize the different bunches with a 6D gaussian
         phase space.
+    init_bunch_list_mpi(bunch, filling_pattern)
+        Initialize a beam using MPI parallelisation with a Bunch per core.
     mpi_init()
         Switch on MPI parallelisation and initialise a Mpi object
     mpi_gather()
@@ -714,6 +716,8 @@ class Beam:
                 raise ValueError(("The length of the bunch list is {} ".format(
                     len(bunch_list)) + "but should be {}".format(self.ring.h)))
             self.bunch_list = bunch_list
+            self.update_filling_pattern()
+            self.update_distance_between_bunches()
 
     def __len__(self):
         """Return the number of (not empty) bunches"""
@@ -864,6 +868,40 @@ class Beam:
             for bunch in self.not_empty:
                 bunch.init_gaussian()
 
+    def init_bunch_list_mpi(self, bunch, filling_pattern):
+        """
+        Initialize a beam using MPI parallelisation with a Bunch per core.
+
+        Parameters
+        ----------
+        bunch : Bunch object
+            The bunch given should probably depend on the mpi.rank so that each 
+            core can track a different bunch.
+            Example: beam.init_bunch_list_mpi(bunch_list[comm.rank], filling_pattern)
+        filling_pattern : array-like of bool of length ring.h
+            Filling pattern of the beam as a list or an array of bool.
+
+        """
+        filling_pattern = np.array(filling_pattern)
+
+        if len(filling_pattern) != self.ring.h:
+            raise ValueError(("The length of filling pattern is {} ".format(
+                len(filling_pattern)) +
+                              "but should be {}".format(self.ring.h)))
+
+        if filling_pattern.dtype != np.dtype("bool"):
+            raise TypeError("dtype {} should be bool.".format(
+                filling_pattern.dtype))
+
+        self.bunch_list = [
+            Bunch(self.ring, mp_number=1, alive=filling_pattern[i])
+            for i in range(self.ring.h)
+        ]
+        self.update_filling_pattern()
+        self.update_distance_between_bunches()
+        self.mpi_init()
+        self[self.mpi.bunch_num] = bunch
+
     def update_filling_pattern(self):
         """Update the beam filling pattern."""
         filling_pattern = []
diff --git a/tests/unit/tracking/test_particle.py b/tests/unit/tracking/test_particle.py
index 0d527f7..2f0a174 100644
--- a/tests/unit/tracking/test_particle.py
+++ b/tests/unit/tracking/test_particle.py
@@ -304,4 +304,25 @@ class TestBeam:
         assert mock_mpi_instance.comm.allgather.called
         beam.mpi_close()
         assert not beam.mpi_switch
-        assert beam.mpi is None
\ No newline at end of file
+        assert beam.mpi is None
+        
+    def test_init_bunch_list(self, demo_ring):
+        filling_pattern = np.ones((demo_ring.h,), dtype=bool)
+        filling_pattern[5] = False
+        filling_pattern[8:10] = False
+        bunch_list = [Bunch(demo_ring, mp_number=1, alive=filling_pattern[i]) for i in range(demo_ring.h)]
+        beam = Beam(demo_ring, bunch_list)
+        assert len(beam) == filling_pattern.sum()
+        np.testing.assert_array_equal(beam.filling_pattern, filling_pattern)
+        assert beam.distance_between_bunches is not None
+        
+    def test_init_bunch_list_mpi(self, demo_ring, generate_bunch):
+        filling_pattern = np.zeros((demo_ring.h,), dtype=bool)
+        filling_pattern[0] = True
+        beam = Beam(demo_ring)
+        beam.init_bunch_list_mpi(generate_bunch(), filling_pattern)
+        
+        assert len(beam) == 1
+        np.testing.assert_array_equal(beam.filling_pattern, filling_pattern)
+        assert beam.distance_between_bunches[0] == demo_ring.h
+        
\ No newline at end of file
-- 
GitLab