diff --git a/docs/source/conf.py b/docs/source/conf.py
index feb3b08bd950890656187542c122e198793e488d..10f665e0560a364251934478d70f7c1b1b244a20 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -12,6 +12,7 @@
 #
 import os
 import sys
+
 sys.path.insert(0, os.path.abspath('.'))
 sys.path.insert(0, os.path.abspath("../"))
 sys.path.insert(0, os.path.abspath("../../"))
diff --git a/mbtrack2/__init__.py b/mbtrack2/__init__.py
index d3383ef03200638c6247696f30aa04af63525fac..abe9587cd67fb5ea64ba86fc5c4464fc82869b5f 100644
--- a/mbtrack2/__init__.py
+++ b/mbtrack2/__init__.py
@@ -4,3 +4,45 @@ from mbtrack2.impedance import *
 from mbtrack2.instability import *
 from mbtrack2.tracking import *
 from mbtrack2.utilities import *
+
+try:
+    DYNAMIC_VERSIONING = True
+    import os
+    import subprocess
+
+    worktree = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    gitdir = worktree + "/.git/"
+    with open(os.devnull, "w") as devnull:
+        __version__ = subprocess.check_output(
+            "git --git-dir=" + gitdir + " --work-tree=" + worktree +
+            " describe --long --dirty --abbrev=10 --tags",
+            shell=True,
+            stderr=devnull,
+        )
+    __version__ = __version__.decode("utf-8").rstrip()  # remove trailing \n
+    # remove commit hash to conform to PEP440:
+    split_ = __version__.split("-")
+    __version__ = split_[0]
+    if split_[1] != "0":
+        __version__ += "." + split_[1]
+    dirty = "dirty" in split_[-1]
+except:
+    DYNAMIC_VERSIONING = False
+    from ._version import __version__
+
+    dirty = False
+
+print(("mbtrack2 version " + __version__))
+if dirty:
+    print("(dirty git work tree)")
+print(50 * '-')
+print(
+    "If used in a publication, please cite mbtrack2 paper and the zenodo archive for the corresponding code version (and other papers for more specific features)."
+)
+print(
+    "[1] A. Gamelin, W. Foosang, N. Yamamoto, V. Gubaidulin and R. Nagaoka, “mbtrack2”. Zenodo, Mar. 25, 2024. doi: 10.5281/zenodo.10871040."
+)
+print(
+    "[2] A. Gamelin, W. Foosang, and R. Nagaoka, “mbtrack2, a Collective Effect Library in Python”, presented at the 12th Int. Particle Accelerator Conf. (IPAC'21), Campinas, Brazil, May 2021, paper MOPAB070."
+)
+print("\n")
diff --git a/mbtrack2/impedance/wakefield.py b/mbtrack2/impedance/wakefield.py
index 458130fe004e0ee0224c026a3cf51e5c9902cdc2..3325ed5a7060b5be0f69c6fc5024507c849e4da1 100644
--- a/mbtrack2/impedance/wakefield.py
+++ b/mbtrack2/impedance/wakefield.py
@@ -851,6 +851,9 @@ class WakeField:
         Save WakeField element to file.
     load(file)
         Load WakeField element from file.
+    save_to_pyheadtail(filename)
+        Save WakeField to PyHEADTAIL format.
+
     """
 
     def __init__(self, structure_list=None, name=None):
@@ -998,6 +1001,26 @@ class WakeField:
         with open(file, "wb") as f:
             pickle.dump(self, f)
 
+    def save_to_pyheadtail(self, file):
+        """Saves wakefield model to pyheadtail format.
+
+        Parameters
+        ----------
+        file : str
+            filename to save the wakefield model
+
+        Returns
+        -------
+        None.
+        """
+        df_to_save = pd.concat(
+            [self.Wlong.data, self.Wxdip.data, self.Wydip.data], axis=1)
+        df_to_save = df_to_save.drop('imag', axis=1)
+        df_to_save.to_csv(file,
+                          sep='\t',
+                          header=[0.0, 0.0, 0.0],
+                          index_label='0.0')
+
     @staticmethod
     def load(file):
         """
diff --git a/mbtrack2/instability/ions.py b/mbtrack2/instability/ions.py
index 2ca737772d467d29b91208b443c615abe0bef25a..ee20a309c184c3af144ba4221554aecb02fd263e 100644
--- a/mbtrack2/instability/ions.py
+++ b/mbtrack2/instability/ions.py
@@ -17,6 +17,7 @@ from scipy.constants import (
     physical_constants,
     pi,
 )
+from scipy.special import k0
 
 rp = 1 / (4*pi*epsilon_0) * e**2 / (m_p * c**2)
 re = physical_constants["classical electron radius"][0]
@@ -25,12 +26,12 @@ re = physical_constants["classical electron radius"][0]
 def ion_cross_section(ring, ion):
     """
     Compute the collisional ionization cross section.
-    
+
     Compute the inelastic collision cross section between a molecule or an atom
-    by a relativistic electron using the relativistic Bethe asymptotic formula 
+    by a relativistic electron using the relativistic Bethe asymptotic formula
     [1].
-    
-    Values of M02 and C02 from [2].
+
+    Values of M02 and C02 from [2] and [3] (values of constants are independent of beam energy).
 
     Parameters
     ----------
@@ -42,29 +43,36 @@ def ion_cross_section(ring, ion):
     -------
     sigma : float
         Cross section in [m**2].
-        
+
     References
     ----------
-    [1] : M. Inokuti, "Inelastic collisions of fast charged particles with 
-    atoms and molecules-the bethe theory revisited", Reviews of modern physics 
+    [1] : M. Inokuti, "Inelastic collisions of fast charged particles with
+    atoms and molecules-the bethe theory revisited", Reviews of modern physics
     43 (1971).
-    [2] : P. F. Tavares, "Bremsstrahlung detection of ions trapped in the EPA 
+    [2] : P. F. Tavares, "Bremsstrahlung detection of ions trapped in the EPA
     electron beam", Part. Accel. 43 (1993).
+    [3] : A. G. Mathewson, S. Zhang, "Beam-gas ionisation cross sections at 7.0 TEV", CERN Tech. rep. Vacuum-Technical-Note-96-01. https://cds.cern.ch/record/1489148/
 
     """
     if ion == "CO":
         M02 = 3.7
-        C0 = 35.1
+        C0 = 35.14
     elif ion == "H2":
-        M02 = 0.7
-        C0 = 8.1
+        M02 = 0.695
+        C0 = 8.115
+    elif ion == "CO2":
+        M02 = 5.75
+        C0 = 55.92
+    elif ion == "CH4":
+        M02 = 4.23
+        C0 = 42.85
     else:
         raise NotImplementedError
 
-    sigma = 4 * pi * (hbar / m_e / c)**2 * (
-        M02 * (1 / ring.beta**2 * np.log(ring.beta**2 /
-                                         (1 - ring.beta**2)) - 1) +
-        C0 / ring.beta**2)
+    sigma = (4 * pi * (hbar / m_e / c)**2 *
+             (M02 * (1 / ring.beta**2 * np.log(ring.beta**2 /
+                                               (1 - ring.beta**2)) - 1) +
+              C0 / ring.beta**2))
 
     return sigma
 
@@ -94,36 +102,27 @@ def ion_frequency(N,
     dim : "y" o "x", optional
         Dimension to consider. The default is "y".
     express : str, optional
-        Expression to use to compute the ion oscillation frequency. 
-        The default is "coupling" corresponding to Gaussian electron and ion 
+        Expression to use to compute the ion oscillation frequency.
+        The default is "coupling" corresponding to Gaussian electron and ion
         distributions with coupling [1].
-        Also possible is "no_coupling" corresponding to Gaussian electron and 
+        Also possible is "no_coupling" corresponding to Gaussian electron and
         ion distributions without coupling [2].
 
     Returns
     -------
     f : float or array
         Ion oscillation frequencies in [Hz].
-        
+
     References
     ----------
-    [1] : T. O. Raubenheimer and F. Zimmermann, "Fast beam-ion instability. I. 
+    [1] : T. O. Raubenheimer and F. Zimmermann, "Fast beam-ion instability. I.
     linear theory and simulations", Physical Review E 52 (1995).
-    [2] : G. V. Stupakov, T. O. Raubenheimer, and F. Zimmermann, "Fast beam-ion 
+    [2] : G. V. Stupakov, T. O. Raubenheimer, and F. Zimmermann, "Fast beam-ion
     instability. II. effect of ion decoherence", Physical Review E 52 (1995).
 
     """
 
-    if ion == "CO":
-        A = 28
-    elif ion == "H2":
-        A = 2
-    elif ion == "CH4":
-        A = 18
-    elif ion == "H2O":
-        A = 16
-    elif ion == "CO2":
-        A = 44
+    ion_mass = {"CO": 28, "H2": 2, "CH4": 18, "H20": 16, "CO2": 44}
 
     if dim == "y":
         pass
@@ -137,43 +136,45 @@ def ion_frequency(N,
     elif express == "no_coupling":
         k = 1
 
-    f = c * np.sqrt(2 * rp * N / (A * k * Lsep * sigmay *
-                                  (sigmax+sigmay))) / (2*pi)
+    f = (c * np.sqrt(2 * rp * N / (ion_mass[ion] * k * Lsep * sigmay *
+                                   (sigmax+sigmay))) / (2*pi))
 
     return f
 
 
-def fast_beam_ion(ring,
-                  Nb,
-                  nb,
-                  Lsep,
-                  sigmax,
-                  sigmay,
-                  P,
-                  T,
-                  beta,
-                  model="linear",
-                  delta_omega=0,
-                  ion="CO",
-                  dim="y"):
+def fast_beam_ion(
+    ring,
+    Nb,
+    nb,
+    Lsep,
+    sigmax,
+    sigmay,
+    P,
+    T,
+    beta,
+    model="linear",
+    delta_omega=0,
+    ion="CO",
+    dim="y",
+):
     """
     Compute fast beam ion instability rise time [1].
-    
-    Warning ! 
-    If model="linear", the rise time is an assymptotic grow time 
+
+    Warning !
+    If model="linear", the rise time is an assymptotic grow time
     (i.e. y ~ exp(sqrt(t/tau))) [1].
-    If model="decoherence", the rise time is an e-folding time 
+    If model="decoherence", the rise time is an e-folding time
     (i.e. y ~ exp(t/tau)) [2].
     If model="non-linear", the rise time is a linear growth time
     (i.e. y ~ t/tau) [3].
-    
+
     The linear model assumes that [1]:
         x,y << sigmax,sigmay
-    
+
     The decoherence model assumes that [2]:
-        Lsep << c / (2 * pi * ion_frequency) 
+        Lsep << c / (2 * pi * ion_frequency)
         Lsep << c / (2 * pi * betatron_frequency)
-        
+
     The non-linear model assumes that [3]:
         x,y >> sigmax,sigmay
 
@@ -212,16 +213,16 @@ def fast_beam_ion(ring,
     -------
     tau : float
         Instability rise time in [s].
-        
+
     References
     ----------
-    [1] : T. O. Raubenheimer and F. Zimmermann, "Fast beam-ion instability. I. 
+    [1] : T. O. Raubenheimer and F. Zimmermann, "Fast beam-ion instability. I.
     linear theory and simulations", Physical Review E 52 (1995).
-    [2] : G. V. Stupakov, T. O. Raubenheimer, and F. Zimmermann, "Fast beam-ion 
+    [2] : G. V. Stupakov, T. O. Raubenheimer, and F. Zimmermann, "Fast beam-ion
     instability. II. effect of ion decoherence", Physical Review E 52 (1995).
-    [3] : Chao, A. W., & Mess, K. H. (Eds.). (2013). Handbook of accelerator 
+    [3] : Chao, A. W., & Mess, K. H. (Eds.). (2013). Handbook of accelerator
     physics and engineering. World scientific. 3rd Printing. p417.
-    
+
     """
     if dim == "y":
         pass
@@ -239,10 +240,10 @@ def fast_beam_ion(ring,
 
     d_gas = P / (Boltzmann*T)
 
-    num = 4 * d_gas * sigma_i * beta * Nb**(3 / 2) * nb**2 * re * rp**(
-        1 / 2) * Lsep**(1 / 2) * c
-    den = 3 * np.sqrt(3) * ring.gamma * sigmay**(3 / 2) * (sigmay + sigmax)**(
-        3 / 2) * A**(1 / 2)
+    num = (4 * d_gas * sigma_i * beta * Nb**(3 / 2) * nb**2 * re *
+           rp**(1 / 2) * Lsep**(1 / 2) * c)
+    den = (3 * np.sqrt(3) * ring.gamma * sigmay**(3 / 2) *
+           (sigmay + sigmax)**(3 / 2) * A**(1 / 2))
 
     tau = den / num
 
@@ -276,11 +277,11 @@ def plot_critical_mass(ring, bunch_charge, bunch_spacing, n_points=1e4):
     Returns
     -------
     fig : figure
-        
+
     References
     ----------
-    [1] : Gamelin, A. (2018). Collective effects in a transient microbunching 
-    regime and ion cloud mitigation in ThomX (Doctoral dissertation, 
+    [1] : Gamelin, A. (2018). Collective effects in a transient microbunching
+    regime and ion cloud mitigation in ThomX (Doctoral dissertation,
     Université Paris-Saclay).
 
     """
@@ -310,3 +311,136 @@ def plot_critical_mass(ring, bunch_charge, bunch_spacing, n_points=1e4):
     ax.set_xlabel("Longitudinal position [m]")
 
     return fig
+
+
+def get_tavares_ion_distribution(x, sigma_x):
+    """
+    Get tranvserse ion distribution
+
+    Parameters
+    ----------
+    x: float, numpy array
+        an array defining the range of values in which to compute the distribution.
+    sigma_x: float
+        rms beam size of an electron bunch (assumed to have a Gaussian distribution) that focuses the ions.
+
+    Returns
+    -------
+    Transverse ion distribution density.
+
+    References
+    ----------
+    [1] Tavares, P. F. (1992). Transverse Distribution of Ions Trapped in an Electron Beam.
+    """
+    return (1 / (pi * np.sqrt(2 * pi) * sigma_x) *
+            k0(x**2 / (2 * sigma_x)**2) * np.exp(-(x**2) / (2 * sigma_x)**2))
+
+
+@np.vectorize
+def find_A_critical(sigma_x,
+                    sigma_y,
+                    bunch_intensity=int(1e9),
+                    bunch_spacing=0.85):
+    """
+    Compute critical mass for ion trapping
+
+    Parameters
+    ----------
+    sigma_x: float, ndarray
+        horizontal rms electron bunch size in [m]
+    sigma_y: float, ndarray
+        vertical rms electron bunch size in [m]
+    bunch_intensity: float
+        number of particles per bunch
+    bunch_spacing: float
+        spacing between bunches in [m]
+
+    Returns
+    -------
+    A_x, A_y: tuple of ndarrays
+        critical ion trapping mass in horizontal and vertical planes
+    References
+    ----------
+
+    """
+    A = bunch_intensity * bunch_spacing * rp / (2 * (sigma_x+sigma_y))
+    A_x, A_y = A / sigma_x, A / sigma_y
+    return A_x, A_y
+
+
+@np.vectorize
+def get_omega_i(sigma_x,
+                sigma_y,
+                ion_mass=28,
+                bunch_intensity=int(1e9),
+                bunch_spacing=0.85):
+    """
+    Compute ion bounce frequency
+
+    Parameters
+    ----------
+    sigma_x: float, numpy array
+        horizontal rms electron bunch size in [m]
+    sigma_y: float, numpy array
+        vertical rms electron bunch size in [m]
+    ion_mass: int
+        Ion mass normalised in [u]
+    bunch_intensity: float
+        number of particles per bunch
+    bunch_spacing: float
+        spacing between bunches in [m]
+
+    Returns
+    -------
+    omega_x, omega_y: tuple of ndarrays
+        Ion bounce oscillation frequency (horizontal, vertical) in [Hz/rad]
+    References
+    ----------
+    """
+    omega_i_times_sqrt_sigma = c * np.sqrt(4 * bunch_intensity * rp /
+                                           (3 * bunch_spacing *
+                                            (sigma_y+sigma_x) * ion_mass))
+    omega_x, omega_y = omega_i_times_sqrt_sigma / np.sqrt(
+        sigma_x), omega_i_times_sqrt_sigma / np.sqrt(sigma_y)
+    return omega_x, omega_y
+
+
+@np.vectorize
+def get_omega_e(
+    sigma_x,
+    sigma_y,
+    ionisation_cross_section,
+    residual_gas_density,
+    bunch_intensity,
+    lorenzt_gamma,
+):
+    """
+    Compute electron-in-ions bounce frequency
+
+    Parameters
+    ----------
+    sigma_x: float, numpy array
+        horizontal rms electron bunch size in [m]
+    sigma_y: float, numpy array
+        vertical rms electron bunch size in [m]
+    ionisation_cross_section:
+        collisional ionisation cross section in [m**2]
+    residual_gas_density: float
+        Residual gas density in [m**-3]
+    bunch_intensity: float
+        number of particles per bunch
+    gamma_r:
+
+    Returns
+    -------
+    omega_x, omega_y: tuple of ndarrays
+        Electron-in-ions oscillation frequency (horizontal, vertical) in [Hz/rad]
+    References
+    ----------
+    """
+    omega_e_times_sqrt_sigma = c * np.sqrt(
+        4 * ionisation_cross_section * residual_gas_density * bunch_intensity *
+        re / (lorenzt_gamma * (sigma_y+sigma_x)))
+    omega_x, omega_y = omega_e_times_sqrt_sigma / np.sqrt(
+        sigma_x), omega_e_times_sqrt_sigma / np.sqrt(sigma_y)
+    return omega_x, omega_y
diff --git a/mbtrack2/tracking/__init__.py b/mbtrack2/tracking/__init__.py
index a877b4689870ca5a9ca0b3cfcd5862b5f4c231ce..cef735267511672fb666bed6db12a187dd13ad3b 100644
--- a/mbtrack2/tracking/__init__.py
+++ b/mbtrack2/tracking/__init__.py
@@ -5,6 +5,12 @@ from mbtrack2.tracking.aperture import (
     LongitudinalAperture,
     RectangularAperture,
 )
+from mbtrack2.tracking.beam_ion_effects import (
+    BeamIonElement,
+    IonAperture,
+    IonMonitor,
+    IonParticles,
+)
 from mbtrack2.tracking.element import (
     Element,
     LongitudinalMap,
@@ -27,6 +33,7 @@ from mbtrack2.tracking.rf import (
     RFCavity,
     TunerLoop,
 )
+from mbtrack2.tracking.spacecharge import TransverseSpaceCharge
 from mbtrack2.tracking.synchrotron import Synchrotron
 from mbtrack2.tracking.wakepotential import (
     LongRangeResistiveWall,
diff --git a/mbtrack2/tracking/aperture.py b/mbtrack2/tracking/aperture.py
index c9316d1727d501073644e86dc07f58d56a59009d..46e5ea980334472973903bd89d9f29c749e3b05e 100644
--- a/mbtrack2/tracking/aperture.py
+++ b/mbtrack2/tracking/aperture.py
@@ -10,9 +10,9 @@ from mbtrack2.tracking.element import Element
 
 class CircularAperture(Element):
     """
-    Circular aperture element. The particles which are outside of the circle 
+    Circular aperture element. The particles which are outside of the circle
     are 'lost' and not used in the tracking any more.
-    
+
     Parameters
     ----------
     radius : float
@@ -29,21 +29,21 @@ class CircularAperture(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
-        alive = bunch.particles["x"]**2 + bunch.particles[
-            "y"]**2 < self.radius_squared
+        alive = (bunch.particles["x"]**2 + bunch.particles["y"]**2
+                 <= self.radius_squared)
         bunch.alive[~alive] = False
 
 
 class ElipticalAperture(Element):
     """
-    Eliptical aperture element. The particles which are outside of the elipse 
+    Eliptical aperture element. The particles which are outside of the elipse
     are 'lost' and not used in the tracking any more.
-    
+
     Parameters
     ----------
     X_radius : float
@@ -64,21 +64,21 @@ class ElipticalAperture(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
         alive = (bunch.particles["x"]**2 / self.X_radius_squared +
-                 bunch.particles["y"]**2 / self.Y_radius_squared < 1)
+                 bunch.particles["y"]**2 / self.Y_radius_squared <= 1)
         bunch.alive[~alive] = False
 
 
 class RectangularAperture(Element):
     """
-    Rectangular aperture element. The particles which are outside of the 
+    Rectangular aperture element. The particles which are outside of the
     rectangle are 'lost' and not used in the tracking any more.
-    
+
     Parameters
     ----------
     X_right : float
@@ -103,23 +103,23 @@ class RectangularAperture(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
 
-        if (self.X_left is None):
-            alive_X = np.abs(bunch.particles["x"]) < self.X_right
+        if self.X_left is None:
+            alive_X = np.abs(bunch.particles["x"]) <= self.X_right
         else:
-            alive_X = ((bunch.particles["x"] < self.X_right) &
-                       (bunch.particles["x"] > self.X_left))
+            alive_X = (bunch.particles["x"]
+                       <= self.X_right) & (bunch.particles["x"] >= self.X_left)
 
-        if (self.Y_bottom is None):
-            alive_Y = np.abs(bunch.particles["y"]) < self.Y_top
+        if self.Y_bottom is None:
+            alive_Y = np.abs(bunch.particles["y"]) <= self.Y_top
         else:
-            alive_Y = ((bunch.particles["y"] < self.Y_top) &
-                       (bunch.particles["y"] > self.Y_bottom))
+            alive_Y = (bunch.particles["y"]
+                       <= self.Y_top) & (bunch.particles["y"] >= self.Y_bottom)
 
         alive = alive_X & alive_Y
         bunch.alive[~alive] = False
@@ -127,9 +127,9 @@ class RectangularAperture(Element):
 
 class LongitudinalAperture(Element):
     """
-    Longitudinal aperture element. The particles which are outside of the 
+    Longitudinal aperture element. The particles which are outside of the
     longitudinal bounds are 'lost' and not used in the tracking any more.
-    
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -152,13 +152,13 @@ class LongitudinalAperture(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
 
-        alive = ((bunch.particles["tau"] < self.tau_up) &
-                 (bunch.particles["tau"] > self.tau_low))
+        alive = (bunch.particles["tau"]
+                 <= self.tau_up) & (bunch.particles["tau"] >= self.tau_low)
 
         bunch.alive[~alive] = False
diff --git a/mbtrack2/tracking/beam_ion_effects.py b/mbtrack2/tracking/beam_ion_effects.py
new file mode 100644
index 0000000000000000000000000000000000000000..458aa1705a855d98de4ff8aa514f467d67616244
--- /dev/null
+++ b/mbtrack2/tracking/beam_ion_effects.py
@@ -0,0 +1,694 @@
+"""
+Module implementing necessary functionalities for beam-ion interactions.
+Classes:
+BeamIonElement
+IonMonitor
+IonAperture
+IonParticles
+"""
+import warnings
+from abc import ABCMeta
+from functools import wraps
+from itertools import count
+
+import h5py as hp
+import numpy as np
+from numpy.random import choice, normal, uniform
+from scipy.constants import c, e
+
+from mbtrack2.tracking.aperture import ElipticalAperture
+from mbtrack2.tracking.element import Element
+from mbtrack2.tracking.monitors import Monitor
+from mbtrack2.tracking.particles import Beam, Bunch
+from mbtrack2.tracking.particles_electromagnetic_fields import (
+    _efieldn_mit,
+    get_displaced_efield,
+)
+
+
+class IonMonitor(Monitor, metaclass=ABCMeta):
+    """
+    A class representing an ion monitor.
+
+    Parameters
+    ----------
+    save_every : int
+        The number of steps between each save operation.
+    buffer_size : int
+        The size of the buffer to store intermediate data.
+    total_size : int
+        The total number of steps to be simulated.
+    file_name : str, optional
+        The name of the HDF5 file to store the data. If not provided, a new file will be created. Defaults to None.
+
+    Methods
+    -------
+    monitor_init(group_name, save_every, buffer_size, total_size, dict_buffer, dict_file, file_name=None, dict_dtype=None)
+        Initialize the monitor object.
+    track(bunch)
+        Tracking method for the element.
+    Raises
+    ------
+    ValueError
+        If total_size is not divisible by buffer_size.
+    """
+
+    _n_monitors = count(0)
+    file = None
+
+    def __init__(self, save_every, buffer_size, total_size, file_name=None):
+        group_name = "IonData_" + str(next(self._n_monitors))
+        dict_buffer = {
+            "mean": (6, buffer_size),
+            "std": (6, buffer_size),
+            "charge": (buffer_size, ),
+            "charge_per_mp": (buffer_size, ),
+        }
+        dict_file = {
+            "mean": (6, total_size),
+            "std": (6, total_size),
+            "charge": (total_size, ),
+            "charge_per_mp": (buffer_size, ),
+        }
+        self.monitor_init(group_name, save_every, buffer_size, total_size,
+                          dict_buffer, dict_file, file_name)
+
+        self.dict_buffer = dict_buffer
+        self.dict_file = dict_file
+
+    def monitor_init(self,
+                     group_name,
+                     save_every,
+                     buffer_size,
+                     total_size,
+                     dict_buffer,
+                     dict_file,
+                     file_name=None,
+                     dict_dtype=None):
+        """
+        Initialize the monitor object.
+
+        Parameters
+        ----------
+        group_name : str
+            The name of the HDF5 group to store the data.
+        save_every : int
+            The number of steps between each save operation.
+        buffer_size : int
+            The size of the buffer to store intermediate data.
+        total_size : int
+            The total number of steps to be simulated.
+        dict_buffer : dict
+            A dictionary containing the names and sizes of the attribute buffers.
+        dict_file : dict
+            A dictionary containing the names and shapes of the datasets to be created.
+        file_name : str, optional
+            The name of the HDF5 file to store the data. If not provided, a new file will be created. Defaults to None.
+        dict_dtype : dict, optional
+            A dictionary containing the names and data types of the datasets. Defaults to None.
+
+        Raises
+        ------
+        ValueError
+            If total_size is not divisible by buffer_size.
+        """
+        if self.file == None:
+            self.file = hp.File(file_name, "a", libver='earliest')
+
+        self.group_name = group_name
+        self.save_every = int(save_every)
+        self.total_size = int(total_size)
+        self.buffer_size = int(buffer_size)
+        if total_size % buffer_size != 0:
+            raise ValueError("total_size must be divisible by buffer_size.")
+        self.buffer_count = 0
+        self.write_count = 0
+        self.track_count = 0
+
+        # setup attribute buffers from values given in dict_buffer
+        for key, value in dict_buffer.items():
+            if dict_dtype == None:
+                self.__setattr__(key, np.zeros(value))
+            else:
+                self.__setattr__(key, np.zeros(value, dtype=dict_dtype[key]))
+        self.time = np.zeros((self.buffer_size, ), dtype=int)
+        # create HDF5 groups and datasets to save data from group_name and
+        # dict_file
+        self.g = self.file.require_group(self.group_name)
+        self.g.require_dataset("time", (self.total_size, ), dtype=int)
+        for key, value in dict_file.items():
+            if dict_dtype == None:
+                self.g.require_dataset(key, value, dtype=float)
+            else:
+                self.g.require_dataset(key, value, dtype=dict_dtype[key])
+
+        # create a dictionary which handle slices
+        slice_dict = {}
+        for key, value in dict_file.items():
+            slice_dict[key] = []
+            for i in range(len(value) - 1):
+                slice_dict[key].append(slice(None))
+        self.slice_dict = slice_dict
+
+    def track(self, object_to_save):
+        if self.track_count % self.save_every == 0:
+            self.to_buffer(object_to_save)
+        self.track_count += 1
+
+
+class IonAperture(ElipticalAperture):
+    """
+    Class representing an ion aperture.
+
+    Inherits from ElipticalAperture. Unlike in ElipticalAperture, ions are removed from IonParticles instead of just being flagged as not "alive".
+    For beam-ion simulations there are too many lost particles and it is better to remove them.
+
+    Attributes
+    ----------
+    X_radius_squared : float
+        The squared radius of the aperture in the x-direction.
+    Y_radius_squared : float
+        The squared radius of the aperture in the y-direction.
+
+    Methods
+    -------
+    track(bunch)
+        Tracking method for the element.
+
+    """
+
+    @Element.parallel
+    def track(self, bunch):
+        """
+        Tracking method for the element.
+        No bunch to bunch interaction, so written for Bunch objects and
+        @Element.parallel is used to handle Beam objects.
+
+        Parameters
+        ----------
+        bunch : Bunch or Beam object
+            The bunch object to be tracked.
+
+        """
+        alive = (bunch.particles["x"]**2 / self.X_radius_squared +
+                 bunch.particles["y"]**2 / self.Y_radius_squared <= 1)
+        for stat in ['x', 'xp', 'y', 'yp', 'tau', 'delta']:
+            bunch.particles[stat] = bunch.particles[stat][alive]
+        bunch.mp_number = len(bunch.particles['x'])
+        bunch.alive = np.ones((bunch.mp_number, ))
+
+
+class IonParticles(Bunch):
+    """
+    Class representing a collection of ion particles.
+
+    Parameters:
+    ----------
+    mp_number : int
+        The number of particles.
+    ion_element_length : float
+        The length of the ion segment.
+    ring : Synchrotron class object
+        The ring object representing the accelerator ring.
+    track_alive : bool, optional
+        Flag indicating whether to track the alive particles. Default is False.
+    alive : bool, optional
+        Flag indicating whether the particles are alive. Default is True.
+    Methods:
+    --------
+    generate_as_a_distribution(electron_bunch)
+        Generates the particle positions based on a normal distribution, taking distribution parameters from an electron bunch.
+    generate_from_random_samples(electron_bunch)
+        Generates the particle positions and times based on random samples from electron positions.
+    """
+
+    def __init__(self,
+                 mp_number,
+                 ion_element_length,
+                 ring,
+                 track_alive=False,
+                 alive=True):
+        self.ring = ring
+        self._mp_number = int(mp_number)
+        self.alive = np.ones((self.mp_number, ), dtype=bool)
+        if not alive:
+            self.alive = np.zeros((self.mp_number, ), dtype=bool)
+            mp_number = 1
+        self.ion_element_length = ion_element_length
+        self.track_alive = track_alive
+        self.current = 0
+        self.particles = {
+            "x": np.zeros((mp_number, ), dtype=np.float64),
+            "xp": np.zeros((mp_number, ), dtype=np.float64),
+            "y": np.zeros((mp_number, ), dtype=np.float64),
+            "yp": np.zeros((mp_number, ), dtype=np.float64),
+            "tau": np.zeros((mp_number, ), dtype=np.float64),
+            "delta": np.zeros((mp_number, ), dtype=np.float64),
+        }
+        self.charge_per_mp = 0
+
+    @property
+    def mp_number(self):
+        """Macro-particle number"""
+        return self._mp_number
+
+    @mp_number.setter
+    def mp_number(self, value):
+        self._mp_number = int(value)
+
+    def generate_as_a_distribution(self, electron_bunch):
+        """
+        Generates the particle positions based on a normal distribution, taking distribution parameters from an electron bunch.
+
+        Parameters:
+        ----------
+        electron_bunch : Bunch
+            An instance of the Bunch class representing the electron bunch.
+        """
+        self["x"], self["y"] = (
+            normal(
+                loc=electron_bunch["x"].mean(),
+                scale=(electron_bunch["x"]).std(),
+                size=self.mp_number,
+            ),
+            normal(
+                loc=electron_bunch["y"].mean(),
+                scale=(electron_bunch["y"]).std(),
+                size=self.mp_number,
+            ),
+        )
+        self["xp"], self["yp"], self["delta"] = (
+            np.zeros((self.mp_number, )),
+            np.zeros((self.mp_number, )),
+            np.zeros((self.mp_number, )),
+        )
+
+        self["tau"] = uniform(
+            low=-self.ion_element_length / c,
+            high=self.ion_element_length / c,
+            size=self.mp_number,
+        )
+
+    def generate_from_random_samples(self, electron_bunch):
+        """
+        Generates the particle positions and times based on random samples from electron positions in the bunch.
+
+        Parameters:
+        ----------
+        electron_bunch : Bunch
+            An instance of the Bunch class representing the electron bunch.
+        """
+        self["x"], self["y"] = (
+            choice(electron_bunch["x"], size=self.mp_number),
+            choice(electron_bunch["y"], size=self.mp_number),
+        )
+        self["xp"], self["yp"], self["delta"] = (
+            np.zeros((self.mp_number, )),
+            np.zeros((self.mp_number, )),
+            np.zeros((self.mp_number, )),
+        )
+        self["tau"] = uniform(
+            low=-self.ion_element_length / c,
+            high=self.ion_element_length / c,
+            size=self.mp_number,
+        )
+
+    def __add__(self, new_particles):
+        self.mp_number += new_particles.mp_number
+        for t in ["x", "xp", "y", "yp", "tau", "delta"]:
+            self.particles[t] = np.append(self.particles[t],
+                                          new_particles.particles[t])
+        self.alive = np.append(
+            self.alive, np.ones((new_particles.mp_number, ), dtype=bool))
+        return self
+
+
+class BeamIonElement(Element):
+    """
+    Represents an element for simulating beam-ion interactions.
+
+    Parameters
+    ----------
+    ion_mass : float
+        The mass of the ions in kg.
+    ion_charge : float
+        The charge of the ions in Coulomb.
+    ionization_cross_section : float
+        The cross section of ionization in meters^2.
+    residual_gas_density : float
+        The residual gas density in meters^-3.
+    ring : instance of Synchrotron()
+        The ring.
+    ion_field_model : str
+        The ion field model, the options are 'weak' (acts on each macroparticle), 'strong' (acts on c.m.), 'PIC'.
+        For 'PIC' the PyPIC package is required.
+    electron_field_model : str
+        The electron field model, the options are 'weak', 'strong', 'PIC'.
+    bunch_spacing : float
+        The bunch spacing, the distance between bunches in meters. Used to propagate the ions in between bunches or in the gaps between the bunches.
+    ion_element_length : float
+        The length of the beam-ion interaction region. For example, if only a single interaction point is used this should be equal to ring.L. 
+    x_radius : float
+        The x radius of the aperture.
+    y_radius : float
+        The y radius of the aperture.
+    n_steps : int
+        The number of records in the built-in ion beam monitor. Should be number of turns times number of bunches because the monitor records every turn after each bunch passage.
+    n_ion_macroparticles_per_bunch : int, optional
+        The number of ion macroparticles generated per electron bunch passed. Defaults to 30.
+    use_ion_phase_space_monitor : bool, optional
+        Whether to use the ion phase space monitor.
+    generate_method : str, optional
+        The method to generate the ion macroparticles, the options are 'distribution', 'samples'. Defaults to 'distribution'. 
+        'Distribution' generates a distribution statistically equivalent to the distribution of electrons. 
+        'Samples' generates ions from random samples of electron positions.
+
+    Methods
+    -------
+    __init__(self, ion_mass, ion_charge, ionization_cross_section, residual_gas_density, ring, ion_field_model, electron_field_model, bunch_spacing, ion_element_length, n_steps, x_radius, y_radius, ion_beam_monitor_name=None, use_ion_phase_space_monitor=False, n_ion_macroparticles_per_bunch=30, generate_method='distribution')
+        Initializes the BeamIonElement object.
+    parallel(track)
+        Defines the decorator @parallel to handle tracking of Beam() objects.
+    clear_ions(self)
+        Clear the ion particles in the ion beam.
+    track_ions_in_a_drift(self, drift_length)
+        Tracks the ions in a drift.
+    generate_new_ions(self, electron_bunch)
+        Generate new ions based on the given electron bunch.
+    track(self, electron_bunch)
+        Beam-ion interaction kicks.
+    
+    Raises
+    ------
+    UserWarning
+        If the BeamIonMonitor object is used, the user should call the close() method at the end of tracking.
+    NotImplementedError
+        If the ion phase space monitor is used.
+    """
+
+    def __init__(self,
+                 ion_mass,
+                 ion_charge,
+                 ionization_cross_section,
+                 residual_gas_density,
+                 ring,
+                 ion_field_model,
+                 electron_field_model,
+                 bunch_spacing,
+                 ion_element_length,
+                 n_steps,
+                 x_radius,
+                 y_radius,
+                 ion_beam_monitor_name=None,
+                 use_ion_phase_space_monitor=False,
+                 n_ion_macroparticles_per_bunch=30,
+                 generate_method='distribution'):
+        if use_ion_phase_space_monitor:
+            raise NotImplementedError(
+                "Ion phase space monitor is not implemented.")
+        self.ring = ring
+        self.bunch_spacing = bunch_spacing
+        self.ion_mass = ion_mass
+        self.ionization_cross_section = ionization_cross_section
+        self.residual_gas_density = residual_gas_density
+        self.ion_charge = ion_charge
+        self.electron_field_model = electron_field_model
+        self.ion_field_model = ion_field_model
+        self.ion_element_length = ion_element_length
+        self.generate_method = generate_method
+        self.n_ion_macroparticles_per_bunch = 30
+        self.ion_beam_monitor_name = ion_beam_monitor_name
+        self.ion_beam = IonParticles(
+            mp_number=1,
+            ion_element_length=self.ion_element_length,
+            ring=self.ring)
+        self.ion_beam["x"] = 0
+        self.ion_beam["xp"] = 0
+        self.ion_beam["y"] = 0
+        self.ion_beam["yp"] = 0
+        self.ion_beam["tau"] = 0
+        self.ion_beam["delta"] = 0
+
+        if self.ion_beam_monitor_name:
+            warnings.warn(
+                'BeamIonMonitor.ion_beam_monitor_name.close() should be called at the end of tracking',
+                UserWarning,
+                stacklevel=2)
+            self.beam_monitor = IonMonitor(
+                1,
+                int(n_steps / 10),
+                n_steps,
+                file_name=self.ion_beam_monitor_name)
+
+        self.aperture = IonAperture(X_radius=x_radius, Y_radius=y_radius)
+
+    def parallel(track):
+        """
+        Defines the decorator @parallel which handle the embarrassingly
+        parallel case which happens when there is no bunch to bunch
+        interaction in the tracking routine.
+
+        Adding @Element.parallel allows to write the track method of the
+        Element subclass for a Bunch object instead of a Beam object.
+
+        Parameters
+        ----------
+        track : function, method of an Element subclass
+            track method of an Element subclass which takes a Bunch object as
+            input
+
+        Returns
+        -------
+        track_wrapper: function, method of an Element subclass
+            track method of an Element subclass which takes a Beam object or a
+            Bunch object as input
+        """
+
+        @wraps(track)
+        def track_wrapper(*args, **kwargs):
+            if isinstance(args[1], Beam):
+                if Beam.switch_mpi:
+                    warnings.warn(
+                        'Tracking through beam-ion element is performed sequentially. Bunches are not parallelized.',
+                        UserWarning,
+                        stacklevel=2)
+                self = args[0]
+                beam = args[1]
+                for bunch in beam.bunch_list:
+                    track(self, bunch, *args[2:], **kwargs)
+            else:
+                self = args[0]
+                bunch = args[1]
+                track(self, bunch, *args[2:], **kwargs)
+
+        return track_wrapper
+
+    def clear_ions(self):
+        """
+        Clear the ion particles in the ion beam.
+        """
+        self.ion_beam.particles = IonParticles(
+            mp_number=1, ion_element_length=self.ion_element_length)
+
+    def track_ions_in_a_drift(self, drift_length):
+        """
+        Tracks the ions in a drift.
+    
+        Parameters
+        ----------
+        drift_length : float
+            The drift length in meters.
+        """
+        drifted_ions_x = self.ion_beam["x"] + drift_length * self.ion_beam["xp"]
+        drifted_ions_y = self.ion_beam["y"] + drift_length * self.ion_beam["yp"]
+
+        self.ion_beam["x"] = drifted_ions_x
+        self.ion_beam["y"] = drifted_ions_y
+
+    def _get_efields(self, first_beam, second_beam, field_model):
+        """
+        Calculates the electromagnetic field of the first beam acting on the second beam for a given field model.
+    
+        Parameters
+        ----------
+        first_beam : IonParticles or Bunch
+            The first beam, represented as an instance of IonParticles() or Bunch().
+        second_beam : IonParticles or Bunch
+            The second beam, represented as an instance of IonParticles() or Bunch().
+        field_model : str, optional
+            The field model used for the interaction. Options are 'weak', 'strong', or 'PIC'.
+    
+        Returns
+        -------
+        en_x : numpy.ndarray
+            The x component of the electric field.
+        en_y : numpy.ndarray
+            The y component of the electric field.
+        """
+        assert field_model in [
+            "weak",
+            "strong",
+            "PIC",
+        ], "The implementation for required beam-ion interaction model {:} is not implemented".format(
+            self.interaction_model)
+        sb_mx, sb_stdx = (
+            second_beam["x"].mean(),
+            second_beam["x"].std(),
+        )
+        sb_my, sb_stdy = (
+            second_beam["y"].mean(),
+            second_beam["y"].std(),
+        )
+        if field_model == "weak":
+            en_x, en_y = get_displaced_efield(
+                _efieldn_mit,
+                first_beam["x"],
+                first_beam["y"],
+                sb_stdx,
+                sb_stdy,
+                sb_mx,
+                sb_my,
+            )
+
+        elif field_model == "strong":
+            fb_mx, fb_my = (
+                first_beam["x"].mean(),
+                first_beam["y"].mean(),
+            )
+            en_x, en_y = get_displaced_efield(_efieldn_mit, fb_mx, fb_my,
+                                              sb_stdx, sb_stdy, sb_mx, sb_my)
+
+        elif field_model == "PIC":
+            from PyPIC import FFT_OpenBoundary
+            from PyPIC import geom_impact_ellip as ellipse
+            qe = e
+            Dx = 0.1 * sb_stdx
+            Dy = 0.1 * sb_stdy
+            x_aper = 10 * sb_stdx
+            y_aper = 10 * sb_stdy
+            chamber = ellipse.ellip_cham_geom_object(x_aper=x_aper,
+                                                     y_aper=y_aper)
+            picFFT = FFT_OpenBoundary.FFT_OpenBoundary(
+                x_aper=chamber.x_aper,
+                y_aper=chamber.y_aper,
+                dx=Dx,
+                dy=Dy,
+                fftlib="pyfftw",
+            )
+            nel_part = 0 * second_beam["x"] + 1.0
+            picFFT.scatter(second_beam["x"], second_beam["y"], nel_part)
+            picFFT.solve()
+            en_x, en_y = picFFT.gather(first_beam["x"], first_beam["y"])
+            en_x /= qe * second_beam["x"].shape[0]
+            en_y /= qe * second_beam["x"].shape[0]
+        return en_x, en_y
+
+    def _get_new_beam_momentum(self,
+                               first_beam,
+                               second_beam,
+                               prefactor,
+                               field_model="strong"):
+        """
+        Calculates the new momentum of the first beam due to the interaction with the second beam.
+        
+        Parameters
+        ----------
+        first_beam : IonParticles or Bunch
+            The first beam, represented as an instance of IonParticles() or Bunch().
+        second_beam : IonParticles or Bunch
+            The second beam, represented as an instance of IonParticles() or Bunch().
+        prefactor : float
+            A scaling factor applied to the calculation of the new momentum.
+        field_model : str
+            The field model used for the interaction. Options are 'weak', 'strong', or 'PIC'.
+            Default is "strong".
+        
+        Returns
+        -------
+        new_xp : numpy.ndarray
+            The new x momentum of the first beam.
+        new_yp : numpy.ndarray
+            The new y momentum of the first beam.
+        """
+
+        en_x, en_y = BeamIonElement._get_efields(first_beam,
+                                                 second_beam,
+                                                 field_model=field_model)
+        kicks_x = prefactor * en_x
+        kicks_y = prefactor * en_y
+        new_xp = first_beam["xp"] + kicks_x
+        new_yp = first_beam["yp"] + kicks_y
+        return new_xp, new_yp
+
+    def _update_beam_momentum(self, beam, new_xp, new_yp):
+        beam["xp"] = new_xp
+        beam["yp"] = new_yp
+
+    def generate_new_ions(self, electron_bunch):
+        """
+        Generate new ions based on the given electron bunch.
+
+        Parameters
+        ----------
+        electron_bunch : ElectronBunch
+            The electron bunch used to generate new ions.
+
+        Returns
+        -------
+        None
+        """
+        new_ion_particles = IonParticles(
+            mp_number=self.n_ion_macroparticles_per_bunch,
+            ion_element_length=self.ion_element_length,
+            ring=self.ring,
+        )
+        if self.generate_method == 'distribution':
+            new_ion_particles.generate_as_a_distribution(
+                electron_bunch=electron_bunch)
+        elif self.generate_method == 'samples':
+            new_ion_particles.generate_from_random_samples(
+                electron_bunch=electron_bunch)
+        self.ion_beam += new_ion_particles
+        self.ion_beam.charge_per_mp = (electron_bunch.charge *
+                                       self.ionization_cross_section *
+                                       self.residual_gas_density *
+                                       self.ion_element_length /
+                                       self.n_ion_macroparticles_per_bunch)
+
+    @parallel
+    def track(self, electron_bunch):
+        """
+        Beam-ion interaction kicks.
+
+        Parameters
+        ----------
+        electron_bunch : Bunch() or Beam() class object
+            An electron bunch to be interacted with.
+        """
+
+        self.generate_new_ions(electron_bunch=electron_bunch)
+
+        self.aperture.track(self.ion_beam)
+
+        if self.ion_beam_monitor_name is not None:
+            self.beam_monitor.track(self.ion_beam)
+
+        prefactor_to_ion_field = -self.ion_beam.charge / (self.ring.E0)
+        prefactor_to_electron_field = -electron_bunch.charge * (
+            e / (self.ion_mass * c**2))
+        new_xp_ions, new_yp_ions = self._get_new_beam_momentum(
+            self.ion_beam,
+            electron_bunch,
+            prefactor_to_electron_field,
+            field_model=self.electron_field_model,
+        )
+        new_xp_electrons, new_yp_electrons = self._get_new_beam_momentum(
+            electron_bunch,
+            self.ion_beam,
+            prefactor_to_ion_field,
+            field_model=self.ion_field_model,
+        )
+        self._update_beam_momentum(self.ion_beam, new_xp_ions, new_yp_ions)
+        self._update_beam_momentum(electron_bunch, new_xp_electrons,
+                                   new_yp_electrons)
+        self.track_ions_in_a_drift(drift_length=self.bunch_spacing)
diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index 47cc401cf796021e47a2995279821d58e265caa6..4385b57d2f44e87d82648d2ba97831c22051f581 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -16,7 +16,7 @@ from mbtrack2.tracking.particles import Beam
 
 class Element(metaclass=ABCMeta):
     """
-    Abstract Element class used for subclass inheritance to define all kinds 
+    Abstract Element class used for subclass inheritance to define all kinds
     of objects which intervene in the tracking.
     """
 
@@ -25,7 +25,7 @@ class Element(metaclass=ABCMeta):
         """
         Track a beam object through this Element.
         This method needs to be overloaded in each Element subclass.
-        
+
         Parameters
         ----------
         beam : Beam object
@@ -35,19 +35,19 @@ class Element(metaclass=ABCMeta):
     @staticmethod
     def parallel(track):
         """
-        Defines the decorator @parallel which handle the embarrassingly 
-        parallel case which happens when there is no bunch to bunch 
+        Defines the decorator @parallel which handle the embarrassingly
+        parallel case which happens when there is no bunch to bunch
         interaction in the tracking routine.
-        
-        Adding @Element.parallel allows to write the track method of the 
+
+        Adding @Element.parallel allows to write the track method of the
         Element subclass for a Bunch object instead of a Beam object.
-        
+
         Parameters
         ----------
         track : function, method of an Element subclass
             track method of an Element subclass which takes a Bunch object as
             input
-            
+
         Returns
         -------
         track_wrapper: function, method of an Element subclass
@@ -60,7 +60,7 @@ class Element(metaclass=ABCMeta):
             if isinstance(args[1], Beam):
                 self = args[0]
                 beam = args[1]
-                if (beam.mpi_switch == True):
+                if beam.mpi_switch == True:
                     track(self, beam[beam.mpi.bunch_num], *args[2:], **kwargs)
                 else:
                     for bunch in beam.not_empty:
@@ -76,7 +76,7 @@ class Element(metaclass=ABCMeta):
 class LongitudinalMap(Element):
     """
     Longitudinal map for a single turn in the synchrotron.
-    
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -91,7 +91,7 @@ class LongitudinalMap(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
@@ -103,9 +103,9 @@ class LongitudinalMap(Element):
 
 class SynchrotronRadiation(Element):
     """
-    Element to handle synchrotron radiation, radiation damping and quantum 
+    Element to handle synchrotron radiation, radiation damping and quantum
     excitation, for a single turn in the synchrotron.
-    
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -123,37 +123,34 @@ class SynchrotronRadiation(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
-        if (self.switch[0] == True):
+        if self.switch[0] == True:
             rand = np.random.normal(size=len(bunch))
-            bunch["delta"] = (
-                (1 - 2 * self.ring.T0 / self.ring.tau[2]) * bunch["delta"] +
-                2 * self.ring.sigma_delta *
-                (self.ring.T0 / self.ring.tau[2])**0.5 * rand)
+            bunch["delta"] = (1 - 2 * self.ring.T0 / self.ring.tau[2]) * bunch[
+                "delta"] + 2 * self.ring.sigma_delta * (
+                    self.ring.T0 / self.ring.tau[2])**0.5 * rand
 
-        if (self.switch[1] == True):
+        if self.switch[1] == True:
             rand = np.random.normal(size=len(bunch))
-            bunch["xp"] = (
-                (1 - 2 * self.ring.T0 / self.ring.tau[0]) * bunch["xp"] +
-                2 * self.ring.sigma()[1] *
-                (self.ring.T0 / self.ring.tau[0])**0.5 * rand)
+            bunch["xp"] = (1 - 2 * self.ring.T0 / self.ring.tau[0]
+                           ) * bunch["xp"] + 2 * self.ring.sigma()[1] * (
+                               self.ring.T0 / self.ring.tau[0])**0.5 * rand
 
-        if (self.switch[2] == True):
+        if self.switch[2] == True:
             rand = np.random.normal(size=len(bunch))
-            bunch["yp"] = (
-                (1 - 2 * self.ring.T0 / self.ring.tau[1]) * bunch["yp"] +
-                2 * self.ring.sigma()[3] *
-                (self.ring.T0 / self.ring.tau[1])**0.5 * rand)
+            bunch["yp"] = (1 - 2 * self.ring.T0 / self.ring.tau[1]
+                           ) * bunch["yp"] + 2 * self.ring.sigma()[3] * (
+                               self.ring.T0 / self.ring.tau[1])**0.5 * rand
 
 
 class TransverseMap(Element):
     """
     Transverse map for a single turn in the synchrotron.
-    
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -170,7 +167,7 @@ class TransverseMap(Element):
                 np.poly1d(self.ring.adts[0]),
                 np.poly1d(self.ring.adts[1]),
                 np.poly1d(self.ring.adts[2]),
-                np.poly1d(self.ring.adts[3])
+                np.poly1d(self.ring.adts[3]),
             ]
 
     @Element.parallel
@@ -179,7 +176,7 @@ class TransverseMap(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
@@ -187,23 +184,29 @@ class TransverseMap(Element):
 
         # Compute phase advance which depends on energy via chromaticity and ADTS
         if self.ring.adts is None:
-            phase_advance_x = 2 * np.pi * (self.ring.tune[0] +
-                                           self.ring.chro[0] * bunch["delta"])
-            phase_advance_y = 2 * np.pi * (self.ring.tune[1] +
-                                           self.ring.chro[1] * bunch["delta"])
+            phase_advance_x = (
+                2 * np.pi *
+                (self.ring.tune[0] + self.ring.chro[0] * bunch["delta"]))
+            phase_advance_y = (
+                2 * np.pi *
+                (self.ring.tune[1] + self.ring.chro[1] * bunch["delta"]))
         else:
-            Jx = (self.ring.optics.local_gamma[0] * bunch['x']**2) + \
-                  (2*self.ring.optics.local_alpha[0] * bunch['x']*bunch['xp']) + \
-                  (self.ring.optics.local_beta[0] * bunch['xp']**2)
-            Jy = (self.ring.optics.local_gamma[1] * bunch['y']**2) + \
-                  (2*self.ring.optics.local_alpha[1] * bunch['y']*bunch['yp']) + \
-                  (self.ring.optics.local_beta[1] * bunch['yp']**2)
-            phase_advance_x = 2 * np.pi * (
-                self.ring.tune[0] + self.ring.chro[0] * bunch["delta"] +
-                self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
-            phase_advance_y = 2 * np.pi * (
-                self.ring.tune[1] + self.ring.chro[1] * bunch["delta"] +
-                self.adts_poly[1](Jx) + self.adts_poly[3](Jy))
+            Jx = ((self.ring.optics.local_gamma[0] * bunch["x"]**2) +
+                  (2 * self.ring.optics.local_alpha[0] * bunch["x"] *
+                   bunch["xp"]) +
+                  (self.ring.optics.local_beta[0] * bunch["xp"]**2))
+            Jy = ((self.ring.optics.local_gamma[1] * bunch["y"]**2) +
+                  (2 * self.ring.optics.local_alpha[1] * bunch["y"] *
+                   bunch["yp"]) +
+                  (self.ring.optics.local_beta[1] * bunch["yp"]**2))
+            phase_advance_x = (
+                2 * np.pi *
+                (self.ring.tune[0] + self.ring.chro[0] * bunch["delta"] +
+                 self.adts_poly[0](Jx) + self.adts_poly[2](Jy)))
+            phase_advance_y = (
+                2 * np.pi *
+                (self.ring.tune[1] + self.ring.chro[1] * bunch["delta"] +
+                 self.adts_poly[1](Jx) + self.adts_poly[3](Jy)))
 
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
         matrix = np.zeros((6, 6, len(bunch)), dtype=np.float64)
@@ -232,14 +235,14 @@ class TransverseMap(Element):
         matrix[4, 5, :] = self.dispersion[3]
         matrix[5, 5, :] = 1
 
-        x = matrix[0, 0, :] * bunch["x"] + matrix[
-            0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"]
-        xp = matrix[1, 0, :] * bunch["x"] + matrix[
-            1, 1, :] * bunch["xp"] + matrix[1, 2, :] * bunch["delta"]
-        y = matrix[3, 3, :] * bunch["y"] + matrix[
-            3, 4, :] * bunch["yp"] + matrix[3, 5, :] * bunch["delta"]
-        yp = matrix[4, 3, :] * bunch["y"] + matrix[
-            4, 4, :] * bunch["yp"] + matrix[4, 5, :] * bunch["delta"]
+        x = (matrix[0, 0] * bunch["x"] + matrix[0, 1] * bunch["xp"] +
+             matrix[0, 2] * bunch["delta"])
+        xp = (matrix[1, 0] * bunch["x"] + matrix[1, 1] * bunch["xp"] +
+              matrix[1, 2] * bunch["delta"])
+        y = (matrix[3, 3] * bunch["y"] + matrix[3, 4] * bunch["yp"] +
+             matrix[3, 5] * bunch["delta"])
+        yp = (matrix[4, 3] * bunch["y"] + matrix[4, 4] * bunch["yp"] +
+              matrix[4, 5] * bunch["delta"])
 
         bunch["x"] = x
         bunch["xp"] = xp
@@ -249,14 +252,14 @@ class TransverseMap(Element):
 
 class SkewQuadrupole:
     """
-    Thin skew quadrupole element used to introduce betatron coupling (the 
+    Thin skew quadrupole element used to introduce betatron coupling (the
     length of the quadrupole is neglected).
-    
+
     Parameters
     ----------
     strength : float
         Integrated strength of the skew quadrupole [m].
-        
+
     """
 
     def __init__(self, strength):
@@ -268,19 +271,19 @@ class SkewQuadrupole:
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
         """
 
-        bunch['xp'] = bunch['xp'] - self.strength * bunch['y']
-        bunch['yp'] = bunch['yp'] - self.strength * bunch['x']
+        bunch["xp"] = bunch["xp"] - self.strength * bunch["y"]
+        bunch["yp"] = bunch["yp"] - self.strength * bunch["x"]
 
 
 class TransverseMapSector(Element):
     """
-    Transverse map for a sector of the synchrotron, from an initial 
+    Transverse map for a sector of the synchrotron, from an initial
     position s0 to a final position s1.
 
     Parameters
@@ -293,20 +296,20 @@ class TransverseMapSector(Element):
         Beta Twiss function at the initial location of the sector.
     dispersion0 : array of shape (4,)
         Dispersion function at the initial location of the sector.
-    alpha1 : array of shape (2,)
+    alpha1: array of shape (2,)
         Alpha Twiss function at the final location of the sector.
     beta1 : array of shape (2,)
         Beta Twiss function at the final location of the sector.
     dispersion1 : array of shape (4,)
         Dispersion function at the final location of the sector.
     phase_diff : array of shape (2,)
-        Phase difference between the initial and final location of the 
+        Phase difference between the initial and final location of the
         sector.
     chro_diff : array of shape (2,)
-        Chromaticity difference between the initial and final location of 
+        Chromaticity difference between the initial and final location of
         the sector.
     adts : array of shape (4,), optional
-        Amplitude-dependent tune shift of the sector, see Synchrotron class 
+        Amplitude-dependent tune shift of the sector, see Synchrotron class
         for details. The default is None.
 
     """
@@ -338,7 +341,7 @@ class TransverseMapSector(Element):
                 np.poly1d(adts[0]),
                 np.poly1d(adts[1]),
                 np.poly1d(adts[2]),
-                np.poly1d(adts[3])
+                np.poly1d(adts[3]),
             ]
         else:
             self.adts_poly = None
@@ -349,7 +352,7 @@ class TransverseMapSector(Element):
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
         @Element.parallel is used to handle Beam objects.
-        
+
         Parameters
         ----------
         bunch : Bunch or Beam object
@@ -357,23 +360,27 @@ class TransverseMapSector(Element):
 
         # Compute phase advance which depends on energy via chromaticity and ADTS
         if self.adts_poly is None:
-            phase_advance_x = 2 * np.pi * (self.tune_diff[0] +
-                                           self.chro_diff[0] * bunch["delta"])
-            phase_advance_y = 2 * np.pi * (self.tune_diff[1] +
-                                           self.chro_diff[1] * bunch["delta"])
+            phase_advance_x = (
+                2 * np.pi *
+                (self.tune_diff[0] + self.chro_diff[0] * bunch["delta"]))
+            phase_advance_y = (
+                2 * np.pi *
+                (self.tune_diff[1] + self.chro_diff[1] * bunch["delta"]))
         else:
-            Jx = (self.gamma0[0] * bunch['x']**2) + \
-                  (2*self.alpha0[0] * bunch['x']*self['xp']) + \
-                  (self.beta0[0] * bunch['xp']**2)
-            Jy = (self.gamma0[1] * bunch['y']**2) + \
-                  (2*self.alpha0[1] * bunch['y']*bunch['yp']) + \
-                  (self.beta0[1] * bunch['yp']**2)
-            phase_advance_x = 2 * np.pi * (
-                self.tune_diff[0] + self.chro_diff[0] * bunch["delta"] +
-                self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
-            phase_advance_y = 2 * np.pi * (
-                self.tune_diff[1] + self.chro_diff[1] * bunch["delta"] +
-                self.adts_poly[1](Jx) + self.adts_poly[3](Jy))
+            Jx = ((self.gamma0[0] * bunch["x"]**2) +
+                  (2 * self.alpha0[0] * bunch["x"] * self["xp"]) +
+                  (self.beta0[0] * bunch["xp"]**2))
+            Jy = ((self.gamma0[1] * bunch["y"]**2) +
+                  (2 * self.alpha0[1] * bunch["y"] * bunch["yp"]) +
+                  (self.beta0[1] * bunch["yp"]**2))
+            phase_advance_x = (
+                2 * np.pi *
+                (self.tune_diff[0] + self.chro_diff[0] * bunch["delta"] +
+                 self.adts_poly[0](Jx) + self.adts_poly[2](Jy)))
+            phase_advance_y = (
+                2 * np.pi *
+                (self.tune_diff[1] + self.chro_diff[1] * bunch["delta"] +
+                 self.adts_poly[1](Jx) + self.adts_poly[3](Jy)))
 
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
         matrix = np.zeros((6, 6, len(bunch)))
@@ -383,18 +390,18 @@ class TransverseMapSector(Element):
             np.cos(phase_advance_x) + self.alpha0[0] * np.sin(phase_advance_x))
         matrix[0, 1, :] = np.sqrt(
             self.beta0[0] * self.beta1[0]) * np.sin(phase_advance_x)
-        matrix[0, 2, :] = self.dispersion1[0] - matrix[
-            0, 0, :] * self.dispersion0[0] - matrix[0,
-                                                    1, :] * self.dispersion0[1]
+        matrix[0, 2, :] = (self.dispersion1[0] -
+                           matrix[0, 0, :] * self.dispersion0[0] -
+                           matrix[0, 1, :] * self.dispersion0[1])
         matrix[1, 0, :] = (
             (self.alpha0[0] - self.alpha1[0]) * np.cos(phase_advance_x) -
             (1 + self.alpha0[0] * self.alpha1[0]) *
             np.sin(phase_advance_x)) / np.sqrt(self.beta0[0] * self.beta1[0])
         matrix[1, 1, :] = np.sqrt(self.beta0[0] / self.beta1[0]) * (
             np.cos(phase_advance_x) - self.alpha1[0] * np.sin(phase_advance_x))
-        matrix[1, 2, :] = self.dispersion1[1] - matrix[
-            1, 0, :] * self.dispersion0[0] - matrix[1,
-                                                    1, :] * self.dispersion0[1]
+        matrix[1, 2, :] = (self.dispersion1[1] -
+                           matrix[1, 0, :] * self.dispersion0[0] -
+                           matrix[1, 1, :] * self.dispersion0[1])
         matrix[2, 2, :] = 1
 
         # Vertical
@@ -402,28 +409,28 @@ class TransverseMapSector(Element):
             np.cos(phase_advance_y) + self.alpha0[1] * np.sin(phase_advance_y))
         matrix[3, 4, :] = np.sqrt(
             self.beta0[1] * self.beta1[1]) * np.sin(phase_advance_y)
-        matrix[3, 5, :] = self.dispersion1[2] - matrix[
-            3, 3, :] * self.dispersion0[2] - matrix[3,
-                                                    4, :] * self.dispersion0[3]
+        matrix[3, 5, :] = (self.dispersion1[2] -
+                           matrix[3, 3, :] * self.dispersion0[2] -
+                           matrix[3, 4, :] * self.dispersion0[3])
         matrix[4, 3, :] = (
             (self.alpha0[1] - self.alpha1[1]) * np.cos(phase_advance_y) -
             (1 + self.alpha0[1] * self.alpha1[1]) *
             np.sin(phase_advance_y)) / np.sqrt(self.beta0[1] * self.beta1[1])
         matrix[4, 4, :] = np.sqrt(self.beta0[1] / self.beta1[1]) * (
             np.cos(phase_advance_y) - self.alpha1[1] * np.sin(phase_advance_y))
-        matrix[4, 5, :] = self.dispersion1[3] - matrix[
-            4, 3, :] * self.dispersion0[2] - matrix[4,
-                                                    4, :] * self.dispersion0[3]
+        matrix[4, 5, :] = (self.dispersion1[3] -
+                           matrix[4, 3, :] * self.dispersion0[2] -
+                           matrix[4, 4, :] * self.dispersion0[3])
         matrix[5, 5, :] = 1
 
-        x = matrix[0, 0, :] * bunch["x"] + matrix[
-            0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"]
-        xp = matrix[1, 0, :] * bunch["x"] + matrix[
-            1, 1, :] * bunch["xp"] + matrix[1, 2, :] * bunch["delta"]
-        y = matrix[3, 3, :] * bunch["y"] + matrix[
-            3, 4, :] * bunch["yp"] + matrix[3, 5, :] * bunch["delta"]
-        yp = matrix[4, 3, :] * bunch["y"] + matrix[
-            4, 4, :] * bunch["yp"] + matrix[4, 5, :] * bunch["delta"]
+        x = (matrix[0, 0, :] * bunch["x"] + matrix[0, 1, :] * bunch["xp"] +
+             matrix[0, 2, :] * bunch["delta"])
+        xp = (matrix[1, 0, :] * bunch["x"] + matrix[1, 1, :] * bunch["xp"] +
+              matrix[1, 2, :] * bunch["delta"])
+        y = (matrix[3, 3, :] * bunch["y"] + matrix[3, 4, :] * bunch["yp"] +
+             matrix[3, 5, :] * bunch["delta"])
+        yp = (matrix[4, 3, :] * bunch["y"] + matrix[4, 4, :] * bunch["yp"] +
+              matrix[4, 5, :] * bunch["delta"])
 
         bunch["x"] = x
         bunch["xp"] = xp
@@ -434,19 +441,22 @@ class TransverseMapSector(Element):
 def transverse_map_sector_generator(ring, positions):
     """
     Convenience function which generate a list of TransverseMapSector elements
-    from a ring with an AT lattice.
-    
-    Tracking through all the sectors is equivalent to a full turn (and thus to 
+    from a ring:
+        - if an AT lattice is loaded, the optics functions and chromaticity is
+        computed at the given positions.
+        - if no AT lattice is loaded, the local optics are used everywhere.
+
+    Tracking through all the sectors is equivalent to a full turn (and thus to
     the TransverseMap object).
 
     Parameters
     ----------
     ring : Synchrotron object
-        Ring parameters, must .
+        Ring parameters.
     positions : array
         List of longitudinal positions in [m] to use as starting and end points
         of the TransverseMapSector elements.
-        The array should contain the initial position (s=0) but not the end 
+        The array should contain the initial position (s=0) but not the end
         position (s=ring.L), so like position = np.array([0, pos1, pos2, ...]).
 
     Returns
@@ -455,56 +465,76 @@ def transverse_map_sector_generator(ring, positions):
         List of TransverseMapSector elements.
 
     """
-    import at
-
-    def _compute_chro(ring, pos, dp=1e-4):
-        lat = deepcopy(ring.optics.lattice)
-        lat.append(at.Marker("END"))
-        N = len(lat)
-        refpts = np.arange(N)
-        *elem_neg_dp, = at.linopt2(lat, refpts=refpts, dp=-dp)
-        *elem_pos_dp, = at.linopt2(lat, refpts=refpts, dp=dp)
-
-        s = elem_neg_dp[2]["s_pos"]
-        mux0 = elem_neg_dp[2]['mu'][:, 0]
-        mux1 = elem_pos_dp[2]['mu'][:, 0]
-        muy0 = elem_neg_dp[2]['mu'][:, 1]
-        muy1 = elem_pos_dp[2]['mu'][:, 1]
-
-        Chrox = (mux1-mux0) / (2*dp) / 2 / np.pi
-        Chroy = (muy1-muy0) / (2*dp) / 2 / np.pi
-        chrox = np.interp(pos, s, Chrox)
-        chroy = np.interp(pos, s, Chroy)
-
-        return np.array([chrox, chroy])
-
-    if ring.optics.use_local_values:
-        raise ValueError(
-            "The Synchrotron object must be loaded from an AT lattice")
-
     N_sec = len(positions)
     sectors = []
-    for i in range(N_sec):
-        alpha0 = ring.optics.alpha(positions[i])
-        beta0 = ring.optics.beta(positions[i])
-        dispersion0 = ring.optics.dispersion(positions[i])
-        mu0 = ring.optics.mu(positions[i])
-        chro0 = _compute_chro(ring, positions[i])
-        if i != (N_sec - 1):
-            alpha1 = ring.optics.alpha(positions[i + 1])
-            beta1 = ring.optics.beta(positions[i + 1])
-            dispersion1 = ring.optics.dispersion(positions[i + 1])
-            mu1 = ring.optics.mu(positions[i + 1])
-            chro1 = _compute_chro(ring, positions[i + 1])
-        else:
-            alpha1 = ring.optics.alpha(positions[0])
-            beta1 = ring.optics.beta(positions[0])
-            dispersion1 = ring.optics.dispersion(positions[0])
-            mu1 = ring.optics.mu(ring.L)
-            chro1 = _compute_chro(ring, ring.L)
-        phase_diff = mu1 - mu0
-        chro_diff = chro1 - chro0
-        sectors.append(
-            TransverseMapSector(ring, alpha0, beta0, dispersion0, alpha1,
-                                beta1, dispersion1, phase_diff, chro_diff))
+    if ring.optics.use_local_values:
+        for i in range(N_sec):
+            sectors.append(
+                TransverseMapSector(
+                    ring,
+                    ring.optics.local_alpha,
+                    ring.optics.local_beta,
+                    ring.optics.local_dispersion,
+                    ring.optics.local_alpha,
+                    ring.optics.local_beta,
+                    ring.optics.local_dispersion,
+                    ring.tune / N_sec,
+                    ring.chro / N_sec,
+                ))
+    else:
+        import at
+
+        def _compute_chro(ring, pos, dp=1e-4):
+            lat = deepcopy(ring.optics.lattice)
+            lat.append(at.Marker("END"))
+            N = len(lat)
+            refpts = np.arange(N)
+            (*elem_neg_dp, ) = at.linopt2(lat, refpts=refpts, dp=-dp)
+            (*elem_pos_dp, ) = at.linopt2(lat, refpts=refpts, dp=dp)
+
+            s = elem_neg_dp[2]["s_pos"]
+            mux0 = elem_neg_dp[2]["mu"][:, 0]
+            mux1 = elem_pos_dp[2]["mu"][:, 0]
+            muy0 = elem_neg_dp[2]["mu"][:, 1]
+            muy1 = elem_pos_dp[2]["mu"][:, 1]
+
+            Chrox = (mux1-mux0) / (2*dp) / 2 / np.pi
+            Chroy = (muy1-muy0) / (2*dp) / 2 / np.pi
+            chrox = np.interp(pos, s, Chrox)
+            chroy = np.interp(pos, s, Chroy)
+
+            return np.array([chrox, chroy])
+
+        for i in range(N_sec):
+            alpha0 = ring.optics.alpha(positions[i])
+            beta0 = ring.optics.beta(positions[i])
+            dispersion0 = ring.optics.dispersion(positions[i])
+            mu0 = ring.optics.mu(positions[i])
+            chro0 = _compute_chro(ring, positions[i])
+            if i != (N_sec - 1):
+                alpha1 = ring.optics.alpha(positions[i + 1])
+                beta1 = ring.optics.beta(positions[i + 1])
+                dispersion1 = ring.optics.dispersion(positions[i + 1])
+                mu1 = ring.optics.mu(positions[i + 1])
+                chro1 = _compute_chro(ring, positions[i + 1])
+            else:
+                alpha1 = ring.optics.alpha(positions[0])
+                beta1 = ring.optics.beta(positions[0])
+                dispersion1 = ring.optics.dispersion(positions[0])
+                mu1 = ring.optics.mu(ring.L)
+                chro1 = _compute_chro(ring, ring.L)
+            phase_diff = mu1 - mu0
+            chro_diff = chro1 - chro0
+            sectors.append(
+                TransverseMapSector(
+                    ring,
+                    alpha0,
+                    beta0,
+                    dispersion0,
+                    alpha1,
+                    beta1,
+                    dispersion1,
+                    phase_diff,
+                    chro_diff,
+                ))
     return sectors
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index 4c3934b75a81798c4a429d89e3b1ef5b6a8ab330..5b973d14ecb43a02663fde0238d9da86c68d2cd6 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -49,10 +49,17 @@ class Proton(Particle):
         super().__init__(m_p, e)
 
 
+class Ion(Particle):
+    """Define an ion"""
+
+    def __init__(self, ion_mass, ion_charge):
+        super().__init(ion_mass * m_p, ion_charge * e)
+
+
 class Bunch:
     """
     Define a bunch object.
-    
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -69,13 +76,13 @@ class Bunch:
     load_from_file : str, optional
         Name of bunch save file generated by save method.
         If None, the bunch is initialized using the input parameters.
-        Otherwise, the bunch from the file is loaded and the other inputs are 
+        Otherwise, the bunch from the file is loaded and the other inputs are
         ignored.
         Default is None.
     load_suffix : str or int, optional
         Suffix to the group name used to load the data from the HDF5 file.
         Default is None.
-        
+
     Attributes
     ----------
     mp_number : int
@@ -95,7 +102,7 @@ class Bunch:
     mean : array of shape (6,)
         Mean position of alive particles for each coordinates.
     std : array of shape (6,)
-        Standard deviation of the position of alive particles for each 
+        Standard deviation of the position of alive particles for each
         coordinates.
     emit : array of shape (3,)
         Bunch emittance for each plane [1].
@@ -113,14 +120,14 @@ class Bunch:
     plot_phasespace(x_var="tau", y_var="delta", plot_type="j")
         Plot phase space.
     save(file_name)
-        Save bunch object data (6D phase space, current, and state) in an HDF5 
+        Save bunch object data (6D phase space, current, and state) in an HDF5
         file format.
     load(file_name)
         Load data from a HDF5 file recorded by Bunch save method.
-    
+
     References
     ----------
-    [1] Wiedemann, H. (2015). Particle accelerator physics. 4th edition. 
+    [1] Wiedemann, H. (2015). Particle accelerator physics. 4th edition.
     Springer, Eq.(8.39) of p224.
     """
 
@@ -246,7 +253,7 @@ class Bunch:
     @property
     def std(self):
         """
-        Return the standard deviation of the position of alive 
+        Return the standard deviation of the position of alive
         particles for each coordinates.
         """
         std = [[self[name].std()] for name in self]
@@ -293,15 +300,15 @@ class Bunch:
         Return the average Courant-Snyder invariant of each plane.
 
         """
-        Jx = (self.ring.optics.local_gamma[0] * self['x']**2) + \
-              (2*self.ring.optics.local_alpha[0] * self['x'])*self['xp'] + \
-              (self.ring.optics.local_beta[0] * self['xp']**2)
-        Jy = (self.ring.optics.local_gamma[1] * self['y']**2) + \
-              (2*self.ring.optics.local_alpha[1] * self['y']*self['yp']) + \
-              (self.ring.optics.local_beta[1] * self['yp']**2)
-        Js = (self.ring.long_gamma * self['tau']**2) + \
-              (2*self.ring.long_alpha * self['tau']*self['delta']) + \
-              (self.ring.long_beta * self['delta']**2)
+        Jx = ((self.ring.optics.local_gamma[0] * self["x"]**2) +
+              (2 * self.ring.optics.local_alpha[0] * self["x"]) * self["xp"] +
+              (self.ring.optics.local_beta[0] * self["xp"]**2))
+        Jy = ((self.ring.optics.local_gamma[1] * self["y"]**2) +
+              (2 * self.ring.optics.local_alpha[1] * self["y"] * self["yp"]) +
+              (self.ring.optics.local_beta[1] * self["yp"]**2))
+        Js = ((self.ring.long_gamma * self["tau"]**2) +
+              (2 * self.ring.long_alpha * self["tau"] * self["delta"]) +
+              (self.ring.long_beta * self["delta"]**2))
         return np.array((np.mean(Jx), np.mean(Jy), np.mean(Js)))
 
     def init_gaussian(self, cov=None, mean=None, **kwargs):
@@ -309,17 +316,17 @@ class Bunch:
         Initialize bunch particles with 6D gaussian phase space.
         Covariance matrix is taken from [1] and dispersion is added following
         the method explained in [2].
-                
+
         Parameters
         ----------
         cov : (6,6) array, optional
             Covariance matrix of the bunch distribution
         mean : (6,) array, optional
             Mean of the bunch distribution
-        
+
         References
         ----------
-        [1] Wiedemann, H. (2015). Particle accelerator physics. 4th 
+        [1] Wiedemann, H. (2015). Particle accelerator physics. 4th
         edition. Springer, Eq.(8.38) of p223.
         [2] http://www.pp.rhul.ac.uk/bdsim/manual-develop/dev_beamgeneration.html
 
@@ -333,10 +340,12 @@ class Bunch:
             optics = kwargs.get("optics", self.ring.optics)
 
             cov = np.zeros((6, 6))
-            cov[0, 0] = self.ring.emit[0] * optics.local_beta[0] + (
-                optics.local_dispersion[0] * self.ring.sigma_delta)**2
-            cov[1, 1] = self.ring.emit[0] * optics.local_gamma[0] + (
-                optics.local_dispersion[1] * self.ring.sigma_delta)**2
+            cov[0,
+                0] = (self.ring.emit[0] * optics.local_beta[0] +
+                      (optics.local_dispersion[0] * self.ring.sigma_delta)**2)
+            cov[1,
+                1] = (self.ring.emit[0] * optics.local_gamma[0] +
+                      (optics.local_dispersion[1] * self.ring.sigma_delta)**2)
             cov[0, 1] = -1 * self.ring.emit[0] * optics.local_alpha[0] + (
                 optics.local_dispersion[0] * optics.local_dispersion[1] *
                 self.ring.sigma_delta**2)
@@ -347,10 +356,12 @@ class Bunch:
             cov[5, 0] = optics.local_dispersion[0] * self.ring.sigma_delta**2
             cov[1, 5] = optics.local_dispersion[1] * self.ring.sigma_delta**2
             cov[5, 1] = optics.local_dispersion[1] * self.ring.sigma_delta**2
-            cov[2, 2] = self.ring.emit[1] * optics.local_beta[1] + (
-                optics.local_dispersion[2] * self.ring.sigma_delta)**2
-            cov[3, 3] = self.ring.emit[1] * optics.local_gamma[1] + (
-                optics.local_dispersion[3] * self.ring.sigma_delta)**2
+            cov[2,
+                2] = (self.ring.emit[1] * optics.local_beta[1] +
+                      (optics.local_dispersion[2] * self.ring.sigma_delta)**2)
+            cov[3,
+                3] = (self.ring.emit[1] * optics.local_gamma[1] +
+                      (optics.local_dispersion[3] * self.ring.sigma_delta)**2)
             cov[2, 3] = -1 * self.ring.emit[1] * optics.local_alpha[1] + (
                 optics.local_dispersion[2] * optics.local_dispersion[3] *
                 self.ring.sigma_delta**2)
@@ -402,7 +413,7 @@ class Bunch:
 
         bins = np.linspace(bin_min, bin_max, n_bin)
         center = (bins[1:] + bins[:-1]) / 2
-        sorted_index = np.searchsorted(bins, self[dimension], side='left')
+        sorted_index = np.searchsorted(bins, self[dimension], side="left")
         sorted_index -= 1
         profile = np.bincount(sorted_index, minlength=n_bin - 1)
 
@@ -472,9 +483,9 @@ class Bunch:
 
     def save(self, file_name, suffix=None, mpi_comm=None):
         """
-        Save bunch object data (6D phase space, current, and state) in an HDF5 
+        Save bunch object data (6D phase space, current, and state) in an HDF5
         file format.
-        
+
         The output file is named as "<file_name>.hdf5".
 
         Parameters
@@ -492,11 +503,13 @@ class Bunch:
         if mpi_comm is None:
             f = hp.File(file_name + ".hdf5", "a", libver="earliest")
         else:
-            f = hp.File(file_name + ".hdf5",
-                        "a",
-                        libver='earliest',
-                        driver='mpio',
-                        comm=mpi_comm)
+            f = hp.File(
+                file_name + ".hdf5",
+                "a",
+                libver="earliest",
+                driver="mpio",
+                comm=mpi_comm,
+            )
 
         if suffix is None:
             group_name = "Bunch"
@@ -539,14 +552,14 @@ class Bunch:
         else:
             group_name = "Bunch_" + str(suffix)
 
-        self.mp_number = len(f[group_name]['alive'][:])
+        self.mp_number = len(f[group_name]["alive"][:])
 
         for i, dim in enumerate(self):
             self.particles[dim] = f[group_name]["phasespace"][:, i]
 
-        self.alive = f[group_name]['alive'][:]
-        if f[group_name]['current'][:][0] != 0:
-            self.current = f[group_name]['current'][:][0]
+        self.alive = f[group_name]["alive"][:]
+        if f[group_name]["current"][:][0] != 0:
+            self.current = f[group_name]["current"][:][0]
         else:
             self.charge_per_mp = 0
 
@@ -557,8 +570,8 @@ class Bunch:
 
 class Beam:
     """
-    Define a Beam object composed of several Bunch objects. 
-    
+    Define a Beam object composed of several Bunch objects.
+
     Parameters
     ----------
     ring : Synchrotron object
@@ -583,7 +596,7 @@ class Beam:
     bunch_mean : array of shape (6, ring.h)
         Mean position of alive particles for each bunch
     bunch_std : array of shape (6, ring.h)
-        Standard deviation of the position of alive particles for each bunch        
+        Standard deviation of the position of alive particles for each bunch
     bunch_emit : array of shape (6, ring.h)
         Bunch emittance of alive particles for each bunch
     mpi : Mpi object
@@ -592,17 +605,17 @@ class Beam:
         mpi_init() and mpi_close()
     bunch_index : array of shape (len(self,))
         Return an array with the positions of the non-empty bunches
-        
+
     Methods
     ------
     init_beam(filling_pattern, current_per_bunch=1e-3, mp_per_bunch=1e3)
-        Initialize beam with a given filling pattern and marco-particle number 
+        Initialize beam with a given filling pattern and marco-particle number
         per bunch. Then initialize the different bunches with a 6D gaussian
         phase space.
     mpi_init()
         Switch on MPI parallelisation and initialise a Mpi object
     mpi_gather()
-        Gather beam, all bunches of the different processors are sent to 
+        Gather beam, all bunches of the different processors are sent to
         all processors. Rather slow
     mpi_share_distributions()
         Compute the bunch profile and share it between the different bunches.
@@ -622,7 +635,7 @@ class Beam:
         if bunch_list is None:
             self.init_beam(np.zeros((self.ring.h, 1), dtype=bool))
         else:
-            if (len(bunch_list) != self.ring.h):
+            if len(bunch_list) != self.ring.h:
                 raise ValueError(("The length of the bunch list is {} ".format(
                     len(bunch_list)) + "but should be {}".format(self.ring.h)))
             self.bunch_list = bunch_list
@@ -661,7 +674,7 @@ class Beam:
 
     @property
     def distance_between_bunches(self):
-        """Return an array which contains the distance to the next bunch in 
+        """Return an array which contains the distance to the next bunch in
         units of the RF period (ring.T1)"""
         return self._distance_between_bunches
 
@@ -696,25 +709,27 @@ class Beam:
 
         self._distance_between_bunches = distance
 
-    def init_beam(self,
-                  filling_pattern,
-                  current_per_bunch=1e-3,
-                  mp_per_bunch=1e3,
-                  track_alive=True,
-                  mpi=False):
+    def init_beam(
+        self,
+        filling_pattern,
+        current_per_bunch=1e-3,
+        mp_per_bunch=1e3,
+        track_alive=True,
+        mpi=False,
+    ):
         """
-        Initialize beam with a given filling pattern and marco-particle number 
+        Initialize beam with a given filling pattern and marco-particle number
         per bunch. Then initialize the different bunches with a 6D gaussian
         phase space.
-        
-        If the filling pattern is an array of bool then the current per bunch 
+
+        If the filling pattern is an array of bool then the current per bunch
         is uniform, else the filling pattern can be an array with the current
         in each bunch.
-        
+
         Parameters
         ----------
         filling_pattern : numpy array or list of length ring.h
-            Filling pattern of the beam, can be a list or an array of bool, 
+            Filling pattern of the beam, can be a list or an array of bool,
             then current_per_bunch is used. Or can be an array with the current
             in each bunch.
         current_per_bunch : float, optional
@@ -730,7 +745,7 @@ class Beam:
             other bunches are initialized with a single marco-particle.
         """
 
-        if (len(filling_pattern) != self.ring.h):
+        if len(filling_pattern) != self.ring.h:
             raise ValueError(("The length of filling pattern is {} ".format(
                 len(filling_pattern)) +
                               "but should be {}".format(self.ring.h)))
@@ -839,7 +854,7 @@ class Beam:
 
     @property
     def bunch_std(self):
-        """Return an array with the standard deviation of the position of alive 
+        """Return an array with the standard deviation of the position of alive
         particles for each bunches"""
         bunch_std = np.zeros((6, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
@@ -859,7 +874,7 @@ class Beam:
 
     @property
     def bunch_cs(self):
-        """Return an array with the average Courant-Snyder invariant for each 
+        """Return an array with the average Courant-Snyder invariant for each
         bunch"""
         bunch_cs = np.zeros((3, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
@@ -870,14 +885,15 @@ class Beam:
     def mpi_init(self):
         """Switch on MPI parallelisation and initialise a Mpi object"""
         from mbtrack2.tracking.parallel import Mpi
+
         self.mpi = Mpi(self.filling_pattern)
         self.mpi_switch = True
 
     def mpi_gather(self):
-        """Gather beam, all bunches of the different processors are sent to 
+        """Gather beam, all bunches of the different processors are sent to
         all processors. Rather slow"""
 
-        if (self.mpi_switch == False):
+        if self.mpi_switch == False:
             print("Error, mpi is not initialised.")
 
         bunch = self[self.mpi.bunch_num]
@@ -897,17 +913,17 @@ class Beam:
 
         Parameters
         ----------
-        var : str {"bunch_current", "bunch_charge", "bunch_particle", 
+        var : str {"bunch_current", "bunch_charge", "bunch_particle",
                    "bunch_mean", "bunch_std", "bunch_emit"}
             Variable to be plotted.
         option : str, optional
-            If var is "bunch_mean", "bunch_std", or "bunch_emit, option needs 
+            If var is "bunch_mean", "bunch_std", or "bunch_emit, option needs
             to be specified.
-            For "bunch_mean" and "bunch_std", 
+            For "bunch_mean" and "bunch_std",
                 option = {"x","xp","y","yp","tau","delta"}.
             For "bunch_emit", option = {"x","y","s"}.
             The default is None.
-            
+
         Return
         ------
         fig : Figure
@@ -920,7 +936,7 @@ class Beam:
             "bunch_particle": self.bunch_particle,
             "bunch_mean": self.bunch_mean,
             "bunch_std": self.bunch_std,
-            "bunch_emit": self.bunch_emit
+            "bunch_emit": self.bunch_emit,
         }
 
         fig, ax = plt.subplots()
@@ -936,12 +952,20 @@ class Beam:
             }
             scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
             label_mean = [
-                "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                "$\\tau$ (ps)", "$\\delta$"
+                "x (um)",
+                "x' ($\\mu$rad)",
+                "y (um)",
+                "y' ($\\mu$rad)",
+                "$\\tau$ (ps)",
+                "$\\delta$",
             ]
             label_std = [
-                "std x (um)", "std x' ($\\mu$rad)", "std y (um)",
-                "std y' ($\\mu$rad)", "std $\\tau$ (ps)", "std $\\delta$"
+                "std x (um)",
+                "std x' ($\\mu$rad)",
+                "std y (um)",
+                "std y' ($\\mu$rad)",
+                "std $\\tau$ (ps)",
+                "std $\\delta$",
             ]
 
             y_axis = var_dict[var][value_dict[option]]
@@ -952,7 +976,7 @@ class Beam:
 
             ax.plot(np.arange(len(self.filling_pattern)),
                     y_axis * scale[value_dict[option]])
-            ax.set_xlabel('bunch number')
+            ax.set_xlabel("bunch number")
             if var == "bunch_mean":
                 ax.set_ylabel(label_mean[value_dict[option]])
             else:
@@ -971,11 +995,14 @@ class Beam:
             ax.plot(np.arange(len(self.filling_pattern)),
                     y_axis * scale[value_dict[option]])
 
-            if option == "x": label_y = "hor. emittance (nm.rad)"
-            elif option == "y": label_y = "ver. emittance (nm.rad)"
-            elif option == "s": label_y = "long. emittance (fm.rad)"
+            if option == "x":
+                label_y = "hor. emittance (nm.rad)"
+            elif option == "y":
+                label_y = "ver. emittance (nm.rad)"
+            elif option == "s":
+                label_y = "long. emittance (fm.rad)"
 
-            ax.set_xlabel('bunch number')
+            ax.set_xlabel("bunch number")
             ax.set_ylabel(label_y)
 
         elif var == "bunch_current" or var == "bunch_charge" or var == "bunch_particle":
@@ -987,11 +1014,14 @@ class Beam:
 
             ax.plot(np.arange(len(self.filling_pattern)),
                     var_dict[var] * scale[var])
-            ax.set_xlabel('bunch number')
+            ax.set_xlabel("bunch number")
 
-            if var == "bunch_current": label_y = "bunch current (mA)"
-            elif var == "bunch_charge": label_y = "bunch chagre (nC)"
-            else: label_y = "number of particles"
+            if var == "bunch_current":
+                label_y = "bunch current (mA)"
+            elif var == "bunch_charge":
+                label_y = "bunch chagre (nC)"
+            else:
+                label_y = "number of particles"
 
             ax.set_ylabel(label_y)
 
@@ -1004,7 +1034,7 @@ class Beam:
     def save(self, file_name):
         """
         Save beam object data in an HDF5 file format.
-        
+
         The output file is named as "<file_name>.hdf5".
 
         Parameters
@@ -1025,11 +1055,13 @@ class Beam:
                         mp_number = None
                         mp_number = self.mpi.comm.bcast(
                             mp_number, root=self.mpi.bunch_to_rank(i))
-                        f = hp.File(file_name + ".hdf5",
-                                    "a",
-                                    libver='earliest',
-                                    driver='mpio',
-                                    comm=self.mpi.comm)
+                        f = hp.File(
+                            file_name + ".hdf5",
+                            "a",
+                            libver="earliest",
+                            driver="mpio",
+                            comm=self.mpi.comm,
+                        )
                         group_name = "Bunch_" + str(i)
                         g = f.create_group(group_name)
                         g.create_dataset("alive", (mp_number, ), dtype=bool)
@@ -1065,7 +1097,7 @@ class Beam:
             f = hp.File(file_name, "r")
             filling_pattern = []
             for i in range(self.ring.h):
-                current = f["Bunch_" + str(i)]['current'][:][0]
+                current = f["Bunch_" + str(i)]["current"][:][0]
                 filling_pattern.append(current)
 
             self.init_beam(filling_pattern,
diff --git a/mbtrack2/tracking/particles_electromagnetic_fields.py b/mbtrack2/tracking/particles_electromagnetic_fields.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9577b15b251745be03f0048f6e8e497b88d49a
--- /dev/null
+++ b/mbtrack2/tracking/particles_electromagnetic_fields.py
@@ -0,0 +1,209 @@
+"""
+A package dealing with particles electromagnetic fields. 
+For example, it can be applied to space charge, beam-beam force, electron lenses or beam-ion instabilities.
+This is largely adapted from a fork of PyHEADTAIL https://github.com/gubaidulinvadim/PyHEADTAIL.
+Only the fastest Fadeeva implementation of the error function is used here.
+See  Oeftiger, A., de Maria, R., Deniau, L., Li, K., McIntosh, E., Moneta, L., Hegglin, S., Aviral, A. (2016).
+Review of CPU and GPU Fadeeva Implementations. https://cds.cern.ch/record/2207430/files/wepoy044.pdf 
+"""
+
+from functools import wraps
+
+import numpy as np
+from scipy.constants import epsilon_0, pi
+from scipy.special import wofz as _scipy_wofz
+
+
+def _wofz(x, y):
+    """
+    Compute the Faddeeva function w(z) = exp(-z^2) * erfc(-i*z).
+
+    Parameters
+    ----------
+    x : float
+        Real part of the argument.
+    y : float
+        Imaginary part of the argument.
+
+    Returns
+    -------
+    tuple
+        Real and imaginary parts of the Faddeeva function.
+    """
+    res = _scipy_wofz(x + 1j*y)
+    return res.real, res.imag
+
+
+def _sqrt_sig(sig_x, sig_y):
+    """
+    Compute the square root of the difference between the squared transverse rms and vertical rms.
+
+    Parameters
+    ----------
+    sig_x : float
+        Transverse rms of the distribution.
+    sig_y : float
+        Vertical rms of the distribution.
+
+    Returns
+    -------
+    float
+        Square root of the difference between the squared transverse rms and vertical rms.
+    """
+    return np.sqrt(2 * (sig_x*sig_x - sig_y*sig_y))
+
+
+def _efieldn_mit(x, y, sig_x, sig_y):
+    """
+    Returns electromagnetic fields as E_x/Q, E_y/Q in (V/m/Coulomb).
+
+    Parameters
+    ----------
+    x : np.ndarray
+        x coordinates in meters.
+    y : np.ndarray
+        y coordinates in meters.
+    sig_x : float
+        Transverse rms of the distribution in meters.
+    sig_y : float
+        Vertical rms of the distribution in meters.
+
+    Returns
+    -------
+    tuple
+        Normalized electromagnetic fields Ex/Q, Ey/Q in the units of (V/m/Coulomb).
+    """
+    sig_sqrt = _sqrt_sig(sig_x, sig_y)
+    w1re, w1im = _wofz(x / sig_sqrt, y / sig_sqrt)
+    ex = np.exp(-x * x / (2*sig_x*sig_x) - y * y / (2*sig_y*sig_y))
+    w2re, w2im = _wofz(x * sig_y / (sig_x*sig_sqrt),
+                       y * sig_x / (sig_y*sig_sqrt))
+    denom = 2 * epsilon_0 * np.sqrt(pi) * sig_sqrt
+    return (w1im - ex*w2im) / denom, (w1re - ex*w2re) / denom
+
+
+def efieldn_gauss_round(x, y, sig_x, sig_y):
+    """
+    Computes the electromagnetic field of a round Gaussian distribution.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        x coordinates in meters.
+    y : np.ndarray
+        y coordinates in meters.
+    sig_x : float
+        Transverse rms of the distribution in meters.
+    sig_y : float
+        Vertical rms of the distribution in meters.
+
+    Returns
+    -------
+    tuple
+        Normalized electromagnetic fields Ex/Q, Ey/Q in the units of (V/m/Coulomb).
+    """
+    r_squared = x*x + y*y
+    sig_r = sig_x
+    amplitude = (1 - np.exp(-r_squared /
+                            (2*sig_r*sig_r))) / (2*pi*epsilon_0*r_squared)
+    return x * amplitude, y * amplitude
+
+
+def _efieldn_linearized(x, y, sig_x, sig_y):
+    """
+    Computes linearized electromagnetic field.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        x coordinate in meters.
+    y : np.ndarray
+        y coordinate in meters.
+    sig_x : float
+        Vertical rms of the distribution in meters.
+    sig_y : float
+        Vertical rms of the distribution in meters.
+
+    Returns
+    -------
+    tuple
+        Normalized electromagnetic fields Ex/Q, Ey/Q in the units of (V/m/Coulomb).
+    """
+    a = np.sqrt(2) * sig_x
+    b = np.sqrt(2) * sig_y
+    amplitude = 1 / (pi * epsilon_0 * (a+b))
+    return x / a * amplitude, y / b * amplitude
+
+
+def add_sigma_check(efieldn):
+    """
+    Wrapper for a normalized electromagnetic field function.
+    Adds the following actions before calculating the field:
+    1) Exchange x and y quantities if sig_x < sig_y.
+    2) Apply round beam field formula when sig_x is close to sig_y.
+
+    Parameters
+    ----------
+    efieldn : function
+        Function to calculate normalized electromagnetic field.
+
+    Returns
+    -------
+    function
+        Wrapped function, including round beam and inverted sig_x/sig_y.
+    """
+    sigmas_ratio_threshold = 1e-3
+    absolute_threshold = 1e-10
+
+    @wraps(efieldn)
+    def efieldn_checked(x, y, sig_x, sig_y, *args, **kwargs):
+        tol_kwargs = dict(rtol=sigmas_ratio_threshold, atol=absolute_threshold)
+        if np.allclose(sig_x, sig_y, **tol_kwargs):
+            if np.allclose(sig_y, 0, **tol_kwargs):
+                en_x = en_y = np.zeros(x.shape, dtype=np.float64)
+            else:
+                en_x, en_y = efieldn_gauss_round(x, y, sig_x, sig_y, *args,
+                                                 **kwargs)
+        elif np.all(sig_x < sig_y):
+            en_y, en_x = efieldn(y, x, sig_y, sig_x, *args, **kwargs)
+        else:
+            en_x, en_y = efieldn(x, y, sig_x, sig_y, *args, **kwargs)
+        return en_x, en_y
+
+    return efieldn_checked
+
+
+def get_displaced_efield(efieldn, xr, yr, sig_x, sig_y, mean_x, mean_y):
+    """
+    Compute the charge-normalized electric field components of a two-dimensional Gaussian charge distribution.
+
+    Parameters
+    ----------
+    efieldn : function
+        Calculates electromagnetic field of a given distribution of charges.
+    xr : np.array
+        x coordinates in meters.
+    yr : np.array
+        y coordinates in meters.
+    sig_x : float
+        Horizontal rms size in meters.
+    sig_y : float
+        Vertical rms size in meters.
+    mean_x : float
+        Horizontal mean of the distribution in meters.
+    mean_y : float
+        Vertical mean of the distribution in meters.
+
+    Returns
+    -------
+    tuple
+        Charge-normalized electromagnetic fields with a displaced center of the distribution.
+    """
+    x = xr - mean_x
+    y = yr - mean_y
+    efieldn = add_sigma_check(efieldn)
+    en_x, en_y = efieldn(np.abs(x), np.abs(y), sig_x, sig_y)
+    en_x = np.abs(en_x) * np.sign(x)
+    en_y = np.abs(en_y) * np.sign(y)
+
+    return en_x, en_y
diff --git a/mbtrack2/tracking/spacecharge.py b/mbtrack2/tracking/spacecharge.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb70f6175baa9a4083c31966d2b71f1f43f662d
--- /dev/null
+++ b/mbtrack2/tracking/spacecharge.py
@@ -0,0 +1,100 @@
+"""
+Module for transverse space charge calculations.
+"""
+from scipy.constants import c
+
+from mbtrack2.tracking.element import Element
+from mbtrack2.tracking.particles_electromagnetic_fields import (
+    _efieldn_mit,
+    get_displaced_efield,
+)
+
+
+class TransverseSpaceCharge(Element):
+    """
+    Class representing a transverse space charge element.
+
+    Parameters
+    ----------
+    ring : Synchrotron
+        The synchrotron object representing the particle accelerator ring.
+    interaction_length : float
+        The interaction length of the space charge effect in meters.
+    n_bins : int, optional
+        The number of bins (longitudinal) used for space charge calculations. Default is 100.
+
+    Attributes
+    ----------
+    ratio_threshold : float
+        The ratio numerical threshold for the space charge element to decide if the beam is transversely round.
+    absolute_threshold : float
+        The absolute numerical threshold for the space charge element to decide if the beam is transversely round.
+    ring: Synchrotron
+        Ring object with information about the ring. This class uses ring.E0 and ring.gamma for calculations.
+    efieldn : function
+        The electric field function.
+
+    Methods
+    -------
+    track(bunch)
+        Perform the tracking of the bunch through the space charge element.
+
+    """
+
+    ratio_threshold = 1e-3
+    absolute_threshold = 1e-10
+
+    def __init__(self, ring, interaction_length, n_bins=100):
+        """
+        Initialize the SpaceCharge object.
+
+        Parameters
+        ----------
+        ring : Synchrotron
+            The synchrotron object representing the particle accelerator ring.
+        interaction_length : float
+            The interaction length of the space charge effect in meters.
+        n_bins : int, optional
+            The number of bins (longitudinal) used for space charge calculations. Default is 100.
+        """
+        self.ring = ring
+        self.n_bins = n_bins
+        self.interaction_length = interaction_length
+        self.efieldn = _efieldn_mit
+
+    @Element.parallel
+    def track(self, bunch):
+        """
+        Perform the tracking of the bunch through the space charge element.
+
+        Parameters
+        ----------
+        bunch : Bunch
+            The bunch of particles to be tracked.
+
+        """
+        prefactor = self.interaction_length / (self.ring.E0 *
+                                               self.ring.gamma**2)
+        (bins, sorted_index, profile,
+         center) = bunch.binning(n_bin=self.n_bins)
+        dz = (bins[1] - bins[0]) * c
+        charge_density = bunch.charge_per_mp * profile / dz
+        for bin_index in range(self.n_bins - 1):
+            particle_ids = (bin_index == sorted_index)
+            if len(particle_ids) == 0:
+                continue
+            x = bunch['x'][particle_ids]
+            y = bunch['y'][particle_ids]
+
+            mean_x, std_x = x.mean(), x.std()
+            mean_y, std_y = y.mean(), y.std()
+
+            en_x, en_y = get_displaced_efield(self.efieldn,
+                                              bunch['x'][particle_ids],
+                                              bunch['y'][particle_ids], std_x,
+                                              std_y, mean_x, mean_y)
+
+            kicks_x = prefactor * en_x * charge_density[bin_index]
+            kicks_y = prefactor * en_y * charge_density[bin_index]
+            bunch['xp'][particle_ids] += kicks_x
+            bunch['yp'][particle_ids] += kicks_y
diff --git a/pyproject.toml b/pyproject.toml
index 9d1b011a64e8ad3b256552747640c25461654733..8ac63cbe5d1eda70cb26ce009661bcc3ed3d5f44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.yapf]
 based_on_style = "pep8"
 arithmetic_precedence_indication = true
+blank_line_before_nested_class_or_def = true 
 
 [tool.isort]
 multi_line_output = 3