From 1fd72e167f03a1eece59ae6f761c16a40127048b Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Thu, 20 Feb 2025 12:08:32 +0100
Subject: [PATCH] Improve docstrings and tests

---
 mbtrack2/tracking/beam_ion_effects.py        | 218 ++++++++++++-------
 tests/unit/tracking/test_beam_ion_effects.py |  30 ++-
 2 files changed, 166 insertions(+), 82 deletions(-)

diff --git a/mbtrack2/tracking/beam_ion_effects.py b/mbtrack2/tracking/beam_ion_effects.py
index dec5433..5e49ca5 100644
--- a/mbtrack2/tracking/beam_ion_effects.py
+++ b/mbtrack2/tracking/beam_ion_effects.py
@@ -6,7 +6,6 @@ IonMonitor
 IonAperture
 IonParticles
 """
-import warnings
 from abc import ABCMeta
 from functools import wraps
 from itertools import count
@@ -36,14 +35,18 @@ class IonMonitor(Monitor, metaclass=ABCMeta):
     total_size : int
         The total number of steps to be simulated.
     file_name : str, optional
-        The name of the HDF5 file to store the data. If not provided, a new file will be created. Defaults to None.
+        The name of the HDF5 file to store the data. If not provided, a new 
+        file will be created. 
+        Defaults to None.
 
     Methods
     -------
-    monitor_init(group_name, save_every, buffer_size, total_size, dict_buffer, dict_file, file_name=None, dict_dtype=None)
+    monitor_init(group_name, save_every, buffer_size, total_size, dict_buffer, 
+                 dict_file, file_name=None, dict_dtype=None)
         Initialize the monitor object.
     track(bunch)
         Tracking method for the element.
+
     Raises
     ------
     ValueError
@@ -96,13 +99,18 @@ class IonMonitor(Monitor, metaclass=ABCMeta):
         total_size : int
             The total number of steps to be simulated.
         dict_buffer : dict
-            A dictionary containing the names and sizes of the attribute buffers.
+            A dictionary containing the names and sizes of the attribute 
+            buffers.
         dict_file : dict
-            A dictionary containing the names and shapes of the datasets to be created.
+            A dictionary containing the names and shapes of the datasets to be 
+            created.
         file_name : str, optional
-            The name of the HDF5 file to store the data. If not provided, a new file will be created. Defaults to None.
+            The name of the HDF5 file to store the data. If not provided, a new 
+            file will be created. 
+            Defaults to None.
         dict_dtype : dict, optional
-            A dictionary containing the names and data types of the datasets. Defaults to None.
+            A dictionary containing the names and data types of the datasets. 
+            Defaults to None.
 
         Raises
         ------
@@ -157,8 +165,10 @@ class IonAperture(ElipticalAperture):
     """
     Class representing an ion aperture.
 
-    Inherits from ElipticalAperture. Unlike in ElipticalAperture, ions are removed from IonParticles instead of just being flagged as not "alive".
-    For beam-ion simulations there are too many lost particles and it is better to remove them.
+    Inherits from ElipticalAperture. Unlike in ElipticalAperture, ions are 
+    removed from IonParticles instead of just being flagged as not "alive".
+    For beam-ion simulations there are too many lost particles and it is better 
+    to remove them.
 
     Attributes
     ----------
@@ -208,15 +218,33 @@ class IonParticles(Bunch):
     ring : Synchrotron class object
         The ring object representing the accelerator ring.
     track_alive : bool, optional
-        Flag indicating whether to track the alive particles. Default is False.
+        Flag indicating whether to track the alive particles. 
+        Default is False.
     alive : bool, optional
-        Flag indicating whether the particles are alive. Default is True.
+        Flag indicating whether the particles are alive. 
+        Default is True.
+        
+    Attributes
+    ----------
+    charge : float
+        Bunch charge in [C].
+    mean : array of shape (4,)
+        Charge-weighted mean values for x, xp, y and yp coordinates.
+    mean_xy : tuple of 2 float
+        Charge-weighted mean values for x and y coordinates.
+    std : array of shape (4,)
+        Charge-weighted std values for x, xp, y and yp coordinates.
+    mean_std_xy : tuple of 4 float
+        Charge-weighted mean and std values for x and y coordinates.
+        
     Methods:
     --------
-    generate_as_a_distribution(electron_bunch)
-        Generates the particle positions based on a normal distribution, taking distribution parameters from an electron bunch.
-    generate_from_random_samples(electron_bunch)
-        Generates the particle positions and times based on random samples from electron positions.
+    generate_as_a_distribution(electron_bunch, charge)
+        Generates the particle positions based on a normal distribution, taking 
+        distribution parameters from an electron bunch.
+    generate_from_random_samples(electron_bunch, charge)
+        Generates the particle positions and times based on random samples from 
+        electron positions.
     """
 
     def __init__(self,
@@ -261,9 +289,6 @@ class IonParticles(Bunch):
         return self["charge"].sum()
 
     def _mean_weighted(self, coord):
-        """
-        Return the mean position of alive particles for each coordinates.
-        """
         if self.charge == 0:
             return np.zeros_like(coord, dtype=float)
         else:
@@ -272,10 +297,6 @@ class IonParticles(Bunch):
             return np.squeeze(np.array(mean))
 
     def _mean_std_weighted(self, coord):
-        """
-        Return the standard deviation of the position of alive
-        particles for each coordinates.
-        """
         if self.charge == 0:
             return np.zeros_like(coord,
                                  dtype=float), np.zeros_like(coord,
@@ -293,35 +314,42 @@ class IonParticles(Bunch):
 
     @property
     def mean(self):
+        """Return charge-weighted mean values for x, xp, y and yp coordinates."""
         coord = ["x", "xp", "y", "yp"]
         return self._mean_weighted(coord)
 
     @property
-    def mean_weighted(self):
+    def mean_xy(self):
+        """Return charge-weighted mean values for x and y coordinates."""
         coord = ["x", "y"]
         mean = self._mean_weighted(coord)
         return mean[0], mean[1]
 
     @property
     def std(self):
+        """Return charge-weighted std values for x, xp, y and yp coordinates."""
         coord = ["x", "xp", "y", "yp"]
         _, std = self._mean_std_weighted(coord)
         return std
 
     @property
-    def mean_std_weighted(self):
+    def mean_std_xy(self):
+        """Return charge-weighted mean and std values for x and y coordinates."""
         coord = ["x", "y"]
         mean, std = self._mean_std_weighted(coord)
         return mean[0], mean[1], std[0], std[1]
 
     def generate_as_a_distribution(self, electron_bunch, charge):
         """
-        Generates the particle positions based on a normal distribution, taking distribution parameters from an electron bunch.
+        Generates the particle positions based on a normal distribution, taking 
+        distribution parameters from an electron bunch.
 
         Parameters:
         ----------
         electron_bunch : Bunch
             An instance of the Bunch class representing the electron bunch.
+        charge : float
+            Total ion charge generated in [C].
         """
         if electron_bunch.is_empty:
             raise ValueError("Electron bunch is empty.")
@@ -353,12 +381,15 @@ class IonParticles(Bunch):
 
     def generate_from_random_samples(self, electron_bunch, charge):
         """
-        Generates the particle positions and times based on random samples from electron positions in the bunch.
+        Generates the particle positions and times based on random samples from 
+        electron positions in the bunch.
 
         Parameters:
         ----------
         electron_bunch : Bunch
             An instance of the Bunch class representing the electron bunch.
+        charge : float
+            Total ion charge generated in [C].
         """
         if electron_bunch.is_empty:
             raise ValueError("Electron bunch is empty.")
@@ -392,47 +423,71 @@ class IonParticles(Bunch):
 class BeamIonElement(Element):
     """
     Represents an element for simulating beam-ion interactions.
+    
+    Apertures and monitors for the ion beam (instances of IonAperture and 
+    IonMonitor) can be added to tracking after the BeamIonElement object has 
+    been initialized by using:
+        beam_ion.apertures.append(aperture)
+        beam_ion.monitors.append(monitor)
+        
+    If the a IonMonitor object is used to record the ion beam, at the end of 
+    tracking the user should close the monitor, for example by calling:
+        beam_ion.monitors[0].close()
 
     Parameters
     ----------
     ion_mass : float
-        The mass of the ions in kg.
+        The mass of the ions in [kg].
     ion_charge : float
-        The charge of the ions in Coulomb.
+        The charge of the ions in [C].
     ionization_cross_section : float
-        The cross section of ionization in meters^2.
+        The cross section of ionization in [m^2].
     residual_gas_density : float
-        The residual gas density in meters^-3.
+        The residual gas density in [m^-3].
     ring : instance of Synchrotron()
         The ring.
-    ion_field_model : str
-        The ion field model, the options are 'weak' (acts on each macroparticle), 'strong' (acts on c.m.), 'PIC'.
+    ion_field_model : {'weak', 'strong', 'PIC'}
+        The ion field model used to update electron beam coordinates:
+            - 'weak': ion field acts on each macroparticle.
+            - 'strong': ion field acts on electron bunch c.m. 
+            - 'PIC': a PIC solver is used to get the ion electric field and the
+            result is interpolated on the electron bunch coordinates.
+        For both 'weak' and 'strong' models, the electric field is computed 
+        using the Bassetti-Erskine formula [1], so assuming a Gaussian beam 
+        distribution.
         For 'PIC' the PyPIC package is required.
-    electron_field_model : str
-        The electron field model, the options are 'weak', 'strong', 'PIC'.
+    electron_field_model : {'weak', 'strong', 'PIC'}
+        The electron field model, defined in the same way as ion_field_model.
     ion_element_length : float
-        The length of the beam-ion interaction region. For example, if only a single interaction point is used this should be equal to ring.L. 
-    x_radius : float
-        The x radius of the aperture.
-    y_radius : float
-        The y radius of the aperture.
-    n_steps : int
-        The number of records in the built-in ion beam monitor. Should be number of turns times number of bunches because the monitor records every turn after each bunch passage.
+        The length of the beam-ion interaction region. For example, if only a 
+        single interaction point is used this should be equal to ring.L. 
     n_ion_macroparticles_per_bunch : int, optional
-        The number of ion macroparticles generated per electron bunch passed. Defaults to 30.
-    ion_beam_monitor_name : str, optional
-         If provided, the name of the ion monitor output file. It must end with an extension '.hdf5'.
-         If None, no ion monitor file is generated.
-    use_ion_phase_space_monitor : bool, optional
-        Whether to use the ion phase space monitor.
-    generate_method : str, optional
-        The method to generate the ion macroparticles, the options are 'distribution', 'samples'. Defaults to 'distribution'. 
-        'distribution' generates a distribution statistically equivalent to the distribution of electrons. 
-        'samples' generates ions from random samples of electron positions.
+        The number of ion macroparticles generated per electron bunch passage. 
+        Defaults to 30.
+    generate_method : {'distribution', 'samples'}, optional
+        The method to generate the ion macroparticles:
+            - 'distribution' generates a distribution statistically equivalent 
+            to the distribution of electrons. 
+            - 'samples' generates ions from random samples of electron 
+            positions.
+        Defaults to 'distribution'.
+    generate_ions : bool, optional
+        If True, generate ions during BeamIonElement.track calls.
+        Default is True.
+    beam_ion_interaction : bool, optional
+        If True, update both beam and ion beam coordinate due to beam-ion 
+        interaction during BeamIonElement.track calls.
+        Default is True.
+    ion_drift : bool, optional
+        If True, update ion beam coordinate due to drift time between bunches.
+        Default is True.
 
     Methods
     -------
-    __init__(ion_mass, ion_charge, ionization_cross_section, residual_gas_density, ring, ion_field_model, electron_field_model, ion_element_length, n_steps, x_radius, y_radius, ion_beam_monitor_name=None, use_ion_phase_space_monitor=False, n_ion_macroparticles_per_bunch=30, generate_method='distribution')
+    __init__(ion_mass, ion_charge, ionization_cross_section, 
+             residual_gas_density, ring, ion_field_model, electron_field_model, 
+             ion_element_length, n_ion_macroparticles_per_bunch=30, 
+             generate_method='distribution')
         Initializes the BeamIonElement object.
     parallel(track)
         Defines the decorator @parallel to handle tracking of Beam() objects.
@@ -448,9 +503,13 @@ class BeamIonElement(Element):
     Raises
     ------
     UserWarning
-        If the BeamIonMonitor object is used, the user should call the close() method at the end of tracking.
-    NotImplementedError
-        If the ion phase space monitor is used.
+        If the BeamIonMonitor object is used, the user should call the close() 
+        method at the end of tracking.
+        
+    References
+    ----------
+    [1] : Bassetti, M., & Erskine, G. A. (1980). Closed expression for the 
+    electrical field of a two-dimensional Gaussian charge (No. ISR-TH-80-06).
     """
 
     def __init__(self,
@@ -463,7 +522,10 @@ class BeamIonElement(Element):
                  electron_field_model,
                  ion_element_length,
                  n_ion_macroparticles_per_bunch=30,
-                 generate_method='distribution'):
+                 generate_method='distribution',
+                 generate_ions=True,
+                 beam_ion_interaction=True,
+                 ion_drift=True):
         self.ring = ring
         self.bunch_spacing = ring.L / ring.h
         self.ion_mass = ion_mass
@@ -481,16 +543,10 @@ class BeamIonElement(Element):
             mp_number=1,
             ion_element_length=self.ion_element_length,
             ring=self.ring)
-        self.ion_beam["x"] = 0
-        self.ion_beam["xp"] = 0
-        self.ion_beam["y"] = 0
-        self.ion_beam["yp"] = 0
-        self.ion_beam["tau"] = 0
-        self.ion_beam["delta"] = 0
 
-        self.generate_ions = True
-        self.beam_ion_interaction = True
-        self.ion_drift = True
+        self.generate_ions = generate_ions
+        self.beam_ion_interaction = beam_ion_interaction
+        self.ion_drift = ion_drift
 
         # interfaces for apertures and montiors
         self.apertures = []
@@ -562,16 +618,17 @@ class BeamIonElement(Element):
 
     def _get_efields(self, first_beam, second_beam, field_model):
         """
-        Calculates the electromagnetic field of the second beam acting on the first beam for a given field model.
+        Calculates the electromagnetic field of the second beam acting on the 
+        first beam for a given field model.
     
         Parameters
         ----------
         first_beam : IonParticles or Bunch
-            The first beam, represented as an instance of IonParticles() or Bunch().
+            The first beam, which is being acted on.
         second_beam : IonParticles or Bunch
-            The second beam, represented as an instance of IonParticles() or Bunch().
-        field_model : str, optional
-            The field model used for the interaction. Options are 'weak', 'strong', or 'PIC'.
+            The second beam, which is generating an electric field.
+        field_model : {'weak', 'strong', 'PIC'}
+            The field model used for the interaction.
     
         Returns
         -------
@@ -582,10 +639,10 @@ class BeamIonElement(Element):
         """
         if not field_model in ["weak", "strong", "PIC"]:
             raise ValueError(
-                f"The implementation for required beam-ion interaction model {field_model} is not implemented"
-            )
+                f"The implementation for required beam-ion interaction model \
+                {field_model} is not implemented")
         if isinstance(second_beam, IonParticles):
-            sb_mx, sb_my, sb_stdx, sb_stdy = second_beam.mean_std_weighted
+            sb_mx, sb_my, sb_stdx, sb_stdy = second_beam.mean_std_xy
         else:
             sb_mx, sb_stdx = (
                 second_beam["x"].mean(),
@@ -608,7 +665,7 @@ class BeamIonElement(Element):
 
         elif field_model == "strong":
             if isinstance(first_beam, IonParticles):
-                fb_mx, fb_my = first_beam.mean_weighted
+                fb_mx, fb_my = first_beam.mean_xy
             else:
                 fb_mx, fb_my = (
                     first_beam["x"].mean(),
@@ -648,18 +705,19 @@ class BeamIonElement(Element):
                                prefactor,
                                field_model="strong"):
         """
-        Calculates the new momentum of the first beam due to the interaction with the second beam.
+        Calculates the new momentum of the first beam due to the interaction 
+        with the second beam.
         
         Parameters
         ----------
         first_beam : IonParticles or Bunch
-            The first beam, represented as an instance of IonParticles() or Bunch().
+            The first beam, which is being acted on.
         second_beam : IonParticles or Bunch
-            The second beam, represented as an instance of IonParticles() or Bunch().
+            The second beam, which is generating an electric field.
         prefactor : float
             A scaling factor applied to the calculation of the new momentum.
-        field_model : str
-            The field model used for the interaction. Options are 'weak', 'strong', or 'PIC'.
+        field_model : {'weak', 'strong', 'PIC'}
+            The field model used for the interaction.
             Default is "strong".
         
         Returns
@@ -689,7 +747,7 @@ class BeamIonElement(Element):
 
         Parameters
         ----------
-        electron_bunch : ElectronBunch
+        electron_bunch : Bunch
             The electron bunch used to generate new ions.
 
         Returns
diff --git a/tests/unit/tracking/test_beam_ion_effects.py b/tests/unit/tracking/test_beam_ion_effects.py
index d9c9fc8..ca42a7c 100644
--- a/tests/unit/tracking/test_beam_ion_effects.py
+++ b/tests/unit/tracking/test_beam_ion_effects.py
@@ -237,9 +237,15 @@ class TestBeamIonElement:
                              [('weak','weak'), ('weak','strong'), 
                               ('strong','weak'), ('strong', 'strong')])
     def test_track_bunch_partially_lost(self, generate_beam_ion, small_bunch, ion_field_model, electron_field_model):
-        small_bunch.alive[0:5] = False
         beam_ion = generate_beam_ion(ion_field_model=ion_field_model, electron_field_model=electron_field_model)
         assert_attr_changed(beam_ion, small_bunch, attrs_changed=["xp","yp"])
+        charge_gen = beam_ion.ion_beam.charge
+        
+        # loose half of electron bunch
+        small_bunch.alive[0:5] = False
+        assert_attr_changed(beam_ion, small_bunch, attrs_changed=["xp","yp"])
+        
+        assert np.isclose(beam_ion.ion_beam.charge, charge_gen*1.5)
 
     @pytest.mark.parametrize('ion_field_model, electron_field_model',
                              [('weak','weak'), ('weak','strong'), 
@@ -336,4 +342,24 @@ class TestBeamIonElement:
     
         beam_ion.clear_ions()
         assert len(beam_ion.ion_beam["x"]) == 1
-        assert beam_ion.ion_beam["x"][0] == 0
\ No newline at end of file
+        assert beam_ion.ion_beam["x"][0] == 0
+        
+    # Tracking with a pre-set cloud
+    def test_track_preset_cloud(self, generate_beam_ion, small_bunch, generate_ion_particles):
+        beam_ion = generate_beam_ion()
+        assert beam_ion.ion_beam.charge == 0
+        
+        preset_ion_particles = generate_ion_particles(mp_number=1000)
+        preset_charge = 1e-9
+        preset_ion_particles.generate_as_a_distribution(small_bunch, preset_charge)
+        beam_ion.ion_beam += preset_ion_particles
+        assert np.isclose(beam_ion.ion_beam.charge, preset_charge)
+        
+        beam_ion.generate_ions = False
+        assert_attr_changed(beam_ion, small_bunch, attrs_changed=["xp","yp"])
+        assert np.isclose(beam_ion.ion_beam.charge, preset_charge)
+        
+        beam_ion.generate_ions = True
+        assert_attr_changed(beam_ion, small_bunch, attrs_changed=["xp","yp"])
+        assert beam_ion.ion_beam.charge > preset_charge
+        
\ No newline at end of file
-- 
GitLab