diff --git a/mbtrack2/tracking/beam_ion_effects.py b/mbtrack2/tracking/beam_ion_effects.py
index 717b55fefa29fa83120199404be083bca89a9195..dec5433a45e1510fc453c58ab04208f6dd0c2964 100644
--- a/mbtrack2/tracking/beam_ion_effects.py
+++ b/mbtrack2/tracking/beam_ion_effects.py
@@ -254,12 +254,12 @@ class IonParticles(Bunch):
     @mp_number.setter
     def mp_number(self, value):
         self._mp_number = int(value)
-        
+
     @property
     def charge(self):
         """Bunch charge in [C]"""
         return self["charge"].sum()
-    
+
     def _mean_weighted(self, coord):
         """
         Return the mean position of alive particles for each coordinates.
@@ -267,7 +267,8 @@ class IonParticles(Bunch):
         if self.charge == 0:
             return np.zeros_like(coord, dtype=float)
         else:
-            mean = [[np.average(self[name], weights=self["charge"])] for name in coord]
+            mean = [[np.average(self[name], weights=self["charge"])]
+                    for name in coord]
             return np.squeeze(np.array(mean))
 
     def _mean_std_weighted(self, coord):
@@ -276,35 +277,40 @@ class IonParticles(Bunch):
         particles for each coordinates.
         """
         if self.charge == 0:
-            return np.zeros_like(coord, dtype=float), np.zeros_like(coord, dtype=float)
+            return np.zeros_like(coord,
+                                 dtype=float), np.zeros_like(coord,
+                                                             dtype=float)
         else:
-            mean = [[np.average(self[name], weights=self["charge"])] for name in coord]
+            mean = [[np.average(self[name], weights=self["charge"])]
+                    for name in coord]
             var = []
             for i, name in enumerate(coord):
-                var.append(np.average((self[name]-mean[i])**2, weights=self["charge"]))
+                var.append(
+                    np.average((self[name] - mean[i])**2,
+                               weights=self["charge"]))
             var = np.array(var)
             return np.squeeze(np.array(mean)), np.squeeze(np.sqrt(var))
-    
+
     @property
     def mean(self):
-        coord = ["x","xp","y","yp"]
+        coord = ["x", "xp", "y", "yp"]
         return self._mean_weighted(coord)
-    
+
     @property
     def mean_weighted(self):
-        coord = ["x","y"]
+        coord = ["x", "y"]
         mean = self._mean_weighted(coord)
         return mean[0], mean[1]
-    
+
     @property
     def std(self):
-        coord = ["x","xp","y","yp"]
+        coord = ["x", "xp", "y", "yp"]
         _, std = self._mean_std_weighted(coord)
         return std
-    
+
     @property
     def mean_std_weighted(self):
-        coord = ["x","y"]
+        coord = ["x", "y"]
         mean, std = self._mean_std_weighted(coord)
         return mean[0], mean[1], std[0], std[1]
 
@@ -456,16 +462,8 @@ class BeamIonElement(Element):
                  ion_field_model,
                  electron_field_model,
                  ion_element_length,
-                 n_steps,
-                 x_radius,
-                 y_radius,
-                 ion_beam_monitor_name=None,
-                 use_ion_phase_space_monitor=False,
                  n_ion_macroparticles_per_bunch=30,
                  generate_method='distribution'):
-        if use_ion_phase_space_monitor:
-            raise NotImplementedError(
-                "Ion phase space monitor is not implemented.")
         self.ring = ring
         self.bunch_spacing = ring.L / ring.h
         self.ion_mass = ion_mass
@@ -479,7 +477,6 @@ class BeamIonElement(Element):
         if not self.generate_method in ["distribution", "samples"]:
             raise ValueError("Wrong generate_method.")
         self.n_ion_macroparticles_per_bunch = n_ion_macroparticles_per_bunch
-        self.ion_beam_monitor_name = ion_beam_monitor_name
         self.ion_beam = IonParticles(
             mp_number=1,
             ion_element_length=self.ion_element_length,
@@ -491,18 +488,13 @@ class BeamIonElement(Element):
         self.ion_beam["tau"] = 0
         self.ion_beam["delta"] = 0
 
-        if self.ion_beam_monitor_name:
-            warnings.warn(
-                'BeamIonMonitor.beam_monitor.close() should be called at the end of tracking',
-                UserWarning,
-                stacklevel=2)
-            self.beam_monitor = IonMonitor(
-                1,
-                int(n_steps / 10),
-                n_steps,
-                file_name=self.ion_beam_monitor_name)
+        self.generate_ions = True
+        self.beam_ion_interaction = True
+        self.ion_drift = True
 
-        self.aperture = IonAperture(X_radius=x_radius, Y_radius=y_radius)
+        # interfaces for apertures and montiors
+        self.apertures = []
+        self.monitors = []
 
     def parallel(track):
         """
@@ -711,16 +703,13 @@ class BeamIonElement(Element):
         )
         new_ion_charge = (electron_bunch.charge *
                           self.ionization_cross_section *
-                          self.residual_gas_density *
-                          self.ion_element_length)
+                          self.residual_gas_density * self.ion_element_length)
         if self.generate_method == 'distribution':
             new_ion_particles.generate_as_a_distribution(
-                electron_bunch=electron_bunch,
-                charge=new_ion_charge)
+                electron_bunch=electron_bunch, charge=new_ion_charge)
         elif self.generate_method == 'samples':
             new_ion_particles.generate_from_random_samples(
-                electron_bunch=electron_bunch,
-                charge=new_ion_charge)
+                electron_bunch=electron_bunch, charge=new_ion_charge)
         self.ion_beam += new_ion_particles
 
     @parallel
@@ -739,15 +728,16 @@ class BeamIonElement(Element):
         else:
             empty_bucket = False
 
-        if not empty_bucket:
+        if not empty_bucket and self.generate_ions:
             self.generate_new_ions(electron_bunch=electron_bunch)
 
-        self.aperture.track(self.ion_beam)
+        for aperture in self.apertures:
+            aperture.track(self.ion_beam)
 
-        if self.ion_beam_monitor_name is not None:
-            self.beam_monitor.track(self.ion_beam)
+        for monitor in self.monitors:
+            monitor.track(self.ion_beam)
 
-        if not empty_bucket:
+        if not empty_bucket and self.beam_ion_interaction:
             prefactor_to_ion_field = -self.ion_beam.charge / (self.ring.E0)
             prefactor_to_electron_field = -electron_bunch.charge * (
                 e / (self.ion_mass * c**2))
@@ -767,4 +757,5 @@ class BeamIonElement(Element):
             self._update_beam_momentum(electron_bunch, new_xp_electrons,
                                        new_yp_electrons)
 
-        self.track_ions_in_a_drift(drift_length=self.bunch_spacing)
+        if self.ion_drift:
+            self.track_ions_in_a_drift(drift_length=self.bunch_spacing)
diff --git a/tests/unit/tracking/test_beam_ion_effects.py b/tests/unit/tracking/test_beam_ion_effects.py
index 8a556efe27e65fa8c77396aa39a139cfe4235122..d9c9fc86b04e20b1a697cd8add28f045ad2abc23 100644
--- a/tests/unit/tracking/test_beam_ion_effects.py
+++ b/tests/unit/tracking/test_beam_ion_effects.py
@@ -207,11 +207,6 @@ def generate_beam_ion(demo_ring):
             ion_field_model="strong",
             electron_field_model="strong",
             ion_element_length=demo_ring.L,
-            n_steps=int(demo_ring.h*10),
-            x_radius=0.1,
-            y_radius=0.1,
-            ion_beam_monitor_name=None,
-            use_ion_phase_space_monitor=False,
             n_ion_macroparticles_per_bunch=30,
             generate_method='samples'):
         
@@ -224,11 +219,6 @@ def generate_beam_ion(demo_ring):
             ion_field_model=ion_field_model,
             electron_field_model=electron_field_model, 
             ion_element_length=ion_element_length,
-            n_steps=n_steps,
-            x_radius=x_radius,
-            y_radius=y_radius,
-            ion_beam_monitor_name=ion_beam_monitor_name,
-            use_ion_phase_space_monitor=use_ion_phase_space_monitor,
             n_ion_macroparticles_per_bunch=n_ion_macroparticles_per_bunch,
             generate_method=generate_method)
         return beam_ion
@@ -296,15 +286,20 @@ class TestBeamIonElement:
         assert np.allclose(beam_ion.ion_beam["y"], initial_y + drift_length)
 
     # Monitor records ion beam data at specified intervals when enabled
-    def test_monitor_recording(self, generate_beam_ion, small_bunch, tmp_path):
-        monitor_file = str(tmp_path / "test_monitor.hdf5")
-        with pytest.warns(UserWarning):
-            beam_ion = generate_beam_ion(ion_beam_monitor_name=monitor_file)
+    def test_monitor_recording(self, 
+                               generate_beam_ion, 
+                               small_bunch, 
+                               generate_ion_monitor,
+                               tmp_path):
+        file_name=tmp_path / "test_monitor.hdf5"
+        monitor = generate_ion_monitor(file_name=file_name)
+        beam_ion = generate_beam_ion()
+        beam_ion.monitors.append(monitor)
     
         beam_ion.track(small_bunch)
     
-        assert os.path.exists(monitor_file)
-        with hp.File(monitor_file, 'r') as f:
+        assert os.path.exists(file_name)
+        with hp.File(file_name, 'r') as f:
             cond = False
             for key in f.keys():
                 if key.startswith('IonData'):
@@ -322,12 +317,14 @@ class TestBeamIonElement:
     # Boundary conditions at aperture edges
     def test_aperture_boundary(self, generate_beam_ion, small_bunch):
         x_radius = 0.001
-        beam_ion = generate_beam_ion(x_radius=x_radius, y_radius=x_radius)
+        aperture = IonAperture(X_radius=x_radius, Y_radius=x_radius)
+        beam_ion = generate_beam_ion()
+        beam_ion.apertures.append(aperture)
     
         beam_ion.generate_new_ions(small_bunch)
     
         beam_ion.ion_beam["x"] = np.ones_like(beam_ion.ion_beam["x"]) * (x_radius * 1.1)
-        beam_ion.aperture.track(beam_ion.ion_beam)
+        beam_ion.track(beam_ion.ion_beam)
     
         assert len(beam_ion.ion_beam["x"]) == 0