From def4b209365aeb324a5a331d4efb67aaf5e9e79a Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 23 Nov 2022 12:25:41 +0100
Subject: [PATCH] Add save and load methods for Bunch and Beam

Save and load methods allow to save Bunch and Beam object and reload them later on.
Load method for Beam allow to choose if mpi is used or not.
Initial modification from Watanyu.
---
 mbtrack2/tracking/particles.py     | 181 ++++++++++++++++++++++++++++-
 mbtrack2/tracking/wakepotential.py |   3 +
 2 files changed, 181 insertions(+), 3 deletions(-)

diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index 49b9adf..f086eaa 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -7,6 +7,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 import seaborn as sns
 import pandas as pd
+import h5py as hp
 from scipy.constants import c, m_e, m_p, e
 
 class Particle:
@@ -57,6 +58,15 @@ class Bunch:
         Can be set to False to gain a speed increase.
     alive : bool, optional
         If False, the bunch is defined as empty.
+    load_from_file : str, optional
+        Name of bunch save file generated by save method.
+        If None, the bunch is initialized using the input parameters.
+        Otherwise, the bunch from the file is loaded and the other inputs are 
+        ignored.
+        Default is None.
+    load_suffix : str or int, optional
+        Suffix to the group name used to load the data from the HDF5 file.
+        Default is None.
         
     Attributes
     ----------
@@ -88,7 +98,12 @@ class Bunch:
         Initialize bunch particles with 6D gaussian phase space.
     plot_phasespace(x_var="tau", y_var="delta", plot_type="j")
         Plot phase space.
-        
+    save(file_name)
+        Save bunch object data (6D phase space, current, and state) in an HDF5 
+        file format.
+    load(file_name)
+        Load data from a HDF5 file recorded by Bunch save method.
+    
     References
     ----------
     [1] Wiedemann, H. (2015). Particle accelerator physics. 4th edition. 
@@ -96,7 +111,7 @@ class Bunch:
     """
     
     def __init__(self, ring, mp_number=1e3, current=1e-3, track_alive=True,
-                 alive=True):
+                 alive=True, load_from_file=None, load_suffix=None):
         
         self.ring = ring
         if not alive:
@@ -113,11 +128,13 @@ class Bunch:
         
         self.particles = np.zeros(self.mp_number, self.dtype)
         self.track_alive = track_alive
-        
         self.alive = np.ones((self.mp_number,),dtype=bool)
         self.current = current
         if not alive:
             self.alive = np.zeros((self.mp_number,),dtype=bool)
+            
+        if load_from_file is not None:
+            self.load(load_from_file, load_suffix, track_alive)
         
     def __len__(self):
         """Return the number of alive particles"""
@@ -391,6 +408,87 @@ class Bunch:
             raise ValueError("Plot type not recognised.")
             
         return fig
+    
+    def save(self, file_name, suffix=None, mpi_comm=None):
+        """
+        Save bunch object data (6D phase space, current, and state) in an HDF5 
+        file format.
+        
+        The output file is named as "<file_name>.hdf5".
+
+        Parameters
+        ----------
+        file_name : str
+            Name of the output file.
+        suffix : str or int, optional
+            Suffix to the group name used to save the data into the HDF5 file.
+            If None, it is not used.
+        mpi_comm : MPI communicator, optional
+            For internal use if mpi is used in Beam objects.
+            Default is None.
+        """
+        
+        if mpi_comm is None:
+            f = hp.File(file_name + ".hdf5", "a", libver="earliest")
+        else:
+            f = hp.File(file_name+ ".hdf5", "a", libver='earliest', 
+                        driver='mpio', comm=mpi_comm)
+        
+        if suffix is None:
+            group_name = "Bunch"
+        else:
+            group_name = "Bunch_" + str(suffix)
+            
+        g = f.create_group(group_name)
+        g.create_dataset("alive", (self.mp_number,), dtype=bool)
+        g.create_dataset("phasespace", (self.mp_number, 6), dtype=float)
+        g.create_dataset("current", (1,), dtype=float)
+        
+        f[group_name]["alive"][:] = self.alive
+        f[group_name]["current"][:] = self.current
+        for i, dim in enumerate(self):
+            f[group_name]["phasespace"][:,i] = self.particles[dim]
+        
+        f.close()
+        
+    def load(self, file_name, suffix=None, track_alive=True):
+        """
+        Load data from a HDF5 file recorded by Bunch save method.
+
+        Parameters
+        ----------
+        file_name : str
+            Path to the HDF5 file where the Bunch data is stored.
+        suffix : str or int, optional
+            Suffix to the group name used to load the data from the HDF5 file.
+            If None, it is not used.
+        track_alive : bool, optional
+            If False, the code no longer take into account alive/dead particles.
+            Should be set to True if element such as apertures are used.
+            Can be set to False to gain a speed increase.
+        """
+        
+        f = hp.File(file_name, "r")
+        
+        if suffix is None:
+            group_name = "Bunch"
+        else:
+            group_name = "Bunch_" + str(suffix)
+
+        self.mp_number = len(f[group_name]['alive'][:])
+        
+        for i, dim in enumerate(self):
+            self.particles[dim] = f[group_name]["phasespace"][:,i]
+        
+        self.alive = f[group_name]['alive'][:]
+        if f[group_name]['current'][:][0] != 0:
+            self.current = f[group_name]['current'][:][0]
+        else:
+            self.charge_per_mp = 0
+            
+        self.track_alive = track_alive
+        
+        f.close()
         
 class Beam:
     """
@@ -447,6 +545,10 @@ class Beam:
         Call mpi_gather and switch off MPI parallelisation
     plot(var, option=None)
         Plot variables with respect to bunch number.
+    save(file_name)
+        Save beam object data in an HDF5 file format.
+    load(file_name, mpi)
+        Load data from a HDF5 file recorded by Beam save method.
     """
     
     def __init__(self, ring, bunch_list=None):
@@ -808,4 +910,77 @@ class Beam:
        
         return fig
         
+    def save(self, file_name):
+        """
+        Save beam object data in an HDF5 file format.
+        
+        The output file is named as "<file_name>.hdf5".
+
+        Parameters
+        ----------
+        file_name : str
+            Name of the output file.
+        """
+        if self.mpi_switch is True:
+            for i, bunch in enumerate(self):
+                if i in self.mpi.table[:,1]:
+                    if i == self.mpi.bunch_num:
+                        mp_number = self[self.mpi.bunch_num].mp_number
+                        self.mpi.comm.bcast(mp_number, root=self.mpi.rank)
+                        self[self.mpi.bunch_num].save(file_name, 
+                                                      self.mpi.bunch_num, 
+                                                      self.mpi.comm)
+                    else:
+                        mp_number = None
+                        mp_number = self.mpi.comm.bcast(mp_number, root=self.mpi.bunch_to_rank(i))
+                        f = hp.File(file_name+ ".hdf5", "a", libver='earliest', 
+                                    driver='mpio', comm=self.mpi.comm)
+                        group_name = "Bunch_" + str(i)
+                        g = f.create_group(group_name)
+                        g.create_dataset("alive", (mp_number,), dtype=bool)
+                        g.create_dataset("phasespace", (mp_number, 6), dtype=float)
+                        g.create_dataset("current", (1,), dtype=float)  
+                        f.close()
+                else:
+                    bunch.save(file_name, 
+                               i,
+                               self.mpi.comm)
+        else:
+            for i, bunch in enumerate(self):
+                bunch.save(file_name, i)
     
+    def load(self, file_name, mpi, track_alive=True):
+        """
+        Load data from a HDF5 file recorded by Beam save method.
+
+        Parameters
+        ----------
+        file_name : str
+            Path to the HDF5 file where the Beam data is stored.
+        mpi : bool
+            If True, only a single bunch is fully initialized on each core, the
+            other bunches are initialized with a single marco-particle.
+        track_alive : bool, optional
+            If False, the code no longer take into account alive/dead particles.
+            Should be set to True if element such as apertures are used.
+            Can be set to False to gain a speed increase.
+            The default is True.
+        """
+        if mpi is True:
+            self.__init__(self.ring)
+            f = hp.File(file_name, "r")
+            filling_pattern = []
+            for i in range(self.ring.h):
+                current = f["Bunch_" + str(i)]['current'][:][0]
+                filling_pattern.append(current)
+                    
+            self.init_beam(filling_pattern,
+                           mp_per_bunch=1,
+                           track_alive=track_alive,
+                           mpi=True)
+            self[self.mpi.bunch_num].load(file_name, self.mpi.bunch_num)
+        else:
+            for i, bunch in enumerate(self):
+                bunch.load(file_name, i, track_alive)
+            self.update_filling_pattern()
+            self.update_distance_between_bunches()
\ No newline at end of file
diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py
index fcd0d50..94a2436 100644
--- a/mbtrack2/tracking/wakepotential.py
+++ b/mbtrack2/tracking/wakepotential.py
@@ -77,6 +77,9 @@ class WakePotential(Element):
         self.ring = ring
         self.n_bin = n_bin
         self.check_sampling()
+        
+        # Suppress numpy warning for floating-point operations.
+        np.seterr(invalid='ignore')
             
     def charge_density(self, bunch):
         """
-- 
GitLab