diff --git a/mbtrack2/__init__.py b/mbtrack2/__init__.py
index ca33de1b40b2d16600698e8476137be481510c6e..190f0ac8cf98fec2b8260452d42c32619701e3b0 100644
--- a/mbtrack2/__init__.py
+++ b/mbtrack2/__init__.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
-__version__="0.4"
-from mbtrack2.tracking import *
+__version__ = "0.5.0"
 from mbtrack2.impedance import *
 from mbtrack2.instability import *
-from mbtrack2.utilities import *
\ No newline at end of file
+from mbtrack2.tracking import *
+from mbtrack2.utilities import *
diff --git a/mbtrack2/impedance/__init__.py b/mbtrack2/impedance/__init__.py
index 59a243dfaf20826919019e4e329e03bb8b88c494..33714ed8ee9a2feed9aadabebc70daf3e8d971c9 100644
--- a/mbtrack2/impedance/__init__.py
+++ b/mbtrack2/impedance/__init__.py
@@ -1,16 +1,23 @@
 # -*- coding: utf-8 -*-
-from mbtrack2.impedance.resistive_wall import (skin_depth, 
-                                               CircularResistiveWall, 
-                                               Coating)
-from mbtrack2.impedance.resonator import (Resonator, 
-                                          PureInductive, 
-                                          PureResistive)
-from mbtrack2.impedance.tapers import (StupakovRectangularTaper, 
-                                       StupakovCircularTaper)
-from mbtrack2.impedance.wakefield import (ComplexData, 
-                                          Impedance, 
-                                          WakeFunction, 
-                                          WakeField)
+from mbtrack2.impedance.csr import FreeSpaceCSR, ParallelPlatesCSR
 from mbtrack2.impedance.impedance_model import ImpedanceModel
-from mbtrack2.impedance.csr import (FreeSpaceCSR, 
-                                    ParallelPlatesCSR)
\ No newline at end of file
+from mbtrack2.impedance.resistive_wall import (
+    CircularResistiveWall,
+    Coating,
+    skin_depth,
+)
+from mbtrack2.impedance.resonator import (
+    PureInductive,
+    PureResistive,
+    Resonator,
+)
+from mbtrack2.impedance.tapers import (
+    StupakovCircularTaper,
+    StupakovRectangularTaper,
+)
+from mbtrack2.impedance.wakefield import (
+    ComplexData,
+    Impedance,
+    WakeField,
+    WakeFunction,
+)
diff --git a/mbtrack2/impedance/csr.py b/mbtrack2/impedance/csr.py
index 651e5be34911ea356312da10d29a5b25a64d2b70..592a4dafdb3dc7b3ab3aaae22c2c42a39bf19e89 100644
--- a/mbtrack2/impedance/csr.py
+++ b/mbtrack2/impedance/csr.py
@@ -2,11 +2,13 @@
 """
 Define coherent synchrotron radiation (CSR) wakefields in various models.
 """
-import numpy as np
 import mpmath as mp
+import numpy as np
 from scipy.constants import c, mu_0
 from scipy.special import gamma
-from mbtrack2.impedance.wakefield import WakeField, Impedance, WakeFunction
+
+from mbtrack2.impedance.wakefield import Impedance, WakeField, WakeFunction
+
 
 class FreeSpaceCSR(WakeField):
     """
@@ -36,23 +38,24 @@ class FreeSpaceCSR(WakeField):
     Beams 7.5 (2004): 054403.
 
     """
-    
     def __init__(self, time, frequency, length, radius):
         super().__init__()
-        
+
         self.length = length
         self.radius = radius
-        self.Z0 = mu_0*c
+        self.Z0 = mu_0 * c
 
         Zl = self.LongitudinalImpedance(frequency)
         # Wl = self.LongitudinalWakeFunction(time)
-        
-        Zlong = Impedance(variable = frequency, function = Zl, component_type='long')
+
+        Zlong = Impedance(variable=frequency,
+                          function=Zl,
+                          component_type='long')
         # Wlong = WakeFunction(variable = time, function = Wl, component_type="long")
-        
+
         super().append_to_model(Zlong)
         # super().append_to_model(Wlong)
-        
+
     def LongitudinalImpedance(self, frequency):
         """
         Compute the free space steady-state CSR impedance.
@@ -76,14 +79,16 @@ class FreeSpaceCSR(WakeField):
         radiation using mesh." Physical Review Special Topics-Accelerators and 
         Beams 7.5 (2004): 054403.
         """
-        
-        Zl = (self.Z0 * self.length / (2*np.pi) * gamma(2/3) * 
-              (-1j * 2 * np.pi * frequency / (3 * c * self.radius**2) )**(1/3) )
+
+        Zl = (self.Z0 * self.length / (2 * np.pi) * gamma(2 / 3) *
+              (-1j * 2 * np.pi * frequency /
+               (3 * c * self.radius**2))**(1 / 3))
         return Zl
-    
+
     def LongitudinalWakeFunction(self, time):
         raise NotImplementedError
-        
+
+
 class ParallelPlatesCSR(WakeField):
     """
     Perfectly conducting parallel plates steady-state coherent synchrotron 
@@ -114,29 +119,30 @@ class ParallelPlatesCSR(WakeField):
     Beams 7.5 (2004): 054403.
 
     """
-    
     def __init__(self, time, frequency, length, radius, distance):
         super().__init__()
-        
+
         self.length = length
         self.radius = radius
         self.distance = distance
-        self.Z0 = mu_0*c
+        self.Z0 = mu_0 * c
 
         Zl = self.LongitudinalImpedance(frequency)
         # Wl = self.LongitudinalWakeFunction(time)
-        
-        Zlong = Impedance(variable = frequency, function = Zl, component_type='long')
+
+        Zlong = Impedance(variable=frequency,
+                          function=Zl,
+                          component_type='long')
         # Wlong = WakeFunction(variable = time, function = Wl, component_type="long")
-        
+
         super().append_to_model(Zlong)
         # super().append_to_model(Wlong)
-        
+
     @property
     def threshold(self):
         """Shielding threshold in the parallel plates model in [Hz]."""
-        return (3 * c) / (2 * np.pi) * (self.radius / self.distance ** 3) ** 0.5
-        
+        return (3*c) / (2 * np.pi) * (self.radius / self.distance**3)**0.5
+
     def LongitudinalImpedance(self, frequency, tol=1e-5):
         """
         Compute the CSR impedance using the perfectly conducting parallel 
@@ -162,20 +168,21 @@ class ParallelPlatesCSR(WakeField):
         radiation using mesh." Physical Review Special Topics-Accelerators and 
         Beams 7.5 (2004): 054403.
         """
-        
+
         Zl = np.zeros(frequency.shape, dtype=complex)
-        constant = (2 * np.pi * self.Z0* self.length / self.distance 
-                    * (2 / self.radius)**(1/3) )
+        constant = (2 * np.pi * self.Z0 * self.length / self.distance *
+                    (2 / self.radius)**(1 / 3))
         for i, f in enumerate(frequency):
             k = 2 * mp.pi * f / c
-            
-            sum_value = mp.nsum(lambda p: self.sum_func(p, k), [0,mp.inf], 
-                                tol=tol, method='r+s')
-            
-            Zl[i] = constant * (1/k)**(1/3) * complex(sum_value)
-    
+
+            sum_value = mp.nsum(lambda p: self.sum_func(p, k), [0, mp.inf],
+                                tol=tol,
+                                method='r+s')
+
+            Zl[i] = constant * (1 / k)**(1 / 3) * complex(sum_value)
+
         return Zl
-    
+
     def sum_func(self, p, k):
         """
         Utility function for LongitudinalImpedance.
@@ -190,12 +197,13 @@ class ParallelPlatesCSR(WakeField):
         sum_value : mpc
 
         """
-        xp = (2*p + 1)*mp.pi / self.distance * ( self.radius / 2 / k**2 )**(1/3)
+        xp = (2*p + 1) * mp.pi / self.distance * (self.radius / 2 / k**2)**(1 /
+                                                                            3)
         Ai = mp.airyai(xp**2)
         Bi = mp.airybi(xp**2)
-        Aip = mp.airyai(xp**2,1)
-        Bip = mp.airybi(xp**2,1)
-        return Aip*(Aip + 1j*Bip) + xp**2 * Ai * (Ai + 1j*Bi)
-    
+        Aip = mp.airyai(xp**2, 1)
+        Bip = mp.airybi(xp**2, 1)
+        return Aip * (Aip + 1j*Bip) + xp**2 * Ai * (Ai + 1j*Bi)
+
     def LongitudinalWakeFunction(self, time):
         raise NotImplementedError
diff --git a/mbtrack2/impedance/impedance_model.py b/mbtrack2/impedance/impedance_model.py
index cf586387aea1e5dd7d78f23ee23ec373df9eab9a..96efc7fb265554b7d17c631f5731b48ba4c6d19e 100644
--- a/mbtrack2/impedance/impedance_model.py
+++ b/mbtrack2/impedance/impedance_model.py
@@ -2,21 +2,28 @@
 """
 Module where the ImpedanceModel class is defined.
 """
-import pandas as pd
-import numpy as np
-import matplotlib.pyplot as plt
-from matplotlib import colormaps
 import pickle
 from copy import deepcopy
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from matplotlib import colormaps
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes
 from scipy.integrate import trapz
 from scipy.interpolate import interp1d
-from mpl_toolkits.axes_grid1.inset_locator import inset_axes
-from mbtrack2.utilities.misc import (beam_loss_factor, effective_impedance,
-                                     double_sided_impedance)
-from mbtrack2.utilities.spectrum import (beam_spectrum,
-                                         gaussian_bunch_spectrum,
-                                         spectral_density)
+
 from mbtrack2.impedance.wakefield import WakeField, WakeFunction
+from mbtrack2.utilities.misc import (
+    beam_loss_factor,
+    double_sided_impedance,
+    effective_impedance,
+)
+from mbtrack2.utilities.spectrum import (
+    beam_spectrum,
+    gaussian_bunch_spectrum,
+    spectral_density,
+)
 
 
 class ImpedanceModel():
@@ -72,7 +79,6 @@ class ImpedanceModel():
     load(file)
         Load impedance model from file.
     """
-
     def __init__(self, ring):
         self.ring = ring
         self.optics = self.ring.optics
@@ -167,15 +173,14 @@ class ImpedanceModel():
         local_beta = self.ring.optics.local_beta
         for component_name in wake_sum.components:
             comp = getattr(wake_sum, component_name)
-            weight = ((beta[0, :] ** comp.power_x) *
-                      (beta[1, :] ** comp.power_y))
+            weight = ((beta[0, :]**comp.power_x) * (beta[1, :]**comp.power_y))
             if comp.plane == "x":
                 weight = weight.sum() / local_beta[0]
             if comp.plane == "y":
                 weight = weight.sum() / local_beta[1]
             else:
                 weight = weight.sum()
-            setattr(wake_sum, component_name, weight*comp)
+            setattr(wake_sum, component_name, weight * comp)
         return wake_sum
 
     def compute_sum_names(self):
@@ -225,7 +230,10 @@ class ImpedanceModel():
                     except AttributeError:
                         setattr(self.sum, component_name2, comp2)
 
-    def group_attributes(self, string_in_name, names_to_group=None, property_list=['Zlong']):
+    def group_attributes(self,
+                         string_in_name,
+                         names_to_group=None,
+                         property_list=['Zlong']):
         """Groups attributes in the ImpedanceModel based on a given string pattern.
         Args:
             string_in_name (str): The string pattern used to match attribute names for grouping. If names_to_group is given, this is a name of a new attribute instead.
@@ -258,8 +266,8 @@ class ImpedanceModel():
             for prop in property_list:
                 old_values = getattr(getattr(self, string_in_name), prop)
                 new_values = getattr(getattr(self, attr), prop)
-                setattr(getattr(self, string_in_name),
-                        prop, old_values+new_values)
+                setattr(getattr(self, string_in_name), prop,
+                        old_values + new_values)
             self.sum_names.remove(attr)
             delattr(self, attr)
         self.sum_names.append(string_in_name)
@@ -282,8 +290,12 @@ class ImpedanceModel():
         """
         self.__dict__[new_name] = self.__dict__.pop(old_name)
 
-    def plot_area(self, Z_type="Zlong", component="real", sigma=None,
-                  attr_list=None, zoom=False):
+    def plot_area(self,
+                  Z_type="Zlong",
+                  component="real",
+                  sigma=None,
+                  attr_list=None,
+                  zoom=False):
         """
         Plot the contributions of different kind of WakeFields.
 
@@ -306,23 +318,30 @@ class ImpedanceModel():
             attr_list = self.sum_names
 
         # manage legend
-        Ztype_dict = {"Zlong": 0, "Zxdip": 1,
-                      "Zydip": 2, "Zxquad": 3, "Zyquad": 4}
+        Ztype_dict = {
+            "Zlong": 0,
+            "Zxdip": 1,
+            "Zydip": 2,
+            "Zxquad": 3,
+            "Zyquad": 4
+        }
         scale = [1e-3, 1e-6, 1e-6, 1e-6, 1e-6]
-        label_list = [r"$Z_\mathrm{long} \; (\mathrm{k}\Omega)$",
-                      r"$\frac{1}{\beta_0} \sum_{j} \beta_{x,j} Z_{x,j}^\mathrm{Dip} \; (\mathrm{M}\Omega/\mathrm{m})$",
-                      r"$\frac{1}{\beta_0} \sum_{j} \beta_{y,j} Z_{y,j}^\mathrm{Dip} \; (\mathrm{M}\Omega/\mathrm{m})$",
-                      r"$\frac{1}{\beta_0} \sum_{j} \beta_{x,j} Z_{x,j}^\mathrm{Quad} \; (\mathrm{M}\Omega/\mathrm{m})$",
-                      r"$\frac{1}{\beta_0} \sum_{j} \beta_{y,j} Z_{y,j}^\mathrm{Quad} \; (\mathrm{M}\Omega/\mathrm{m})$"]
+        label_list = [
+            r"$Z_\mathrm{long} \; (\mathrm{k}\Omega)$",
+            r"$\frac{1}{\beta_0} \sum_{j} \beta_{x,j} Z_{x,j}^\mathrm{Dip} \; (\mathrm{M}\Omega/\mathrm{m})$",
+            r"$\frac{1}{\beta_0} \sum_{j} \beta_{y,j} Z_{y,j}^\mathrm{Dip} \; (\mathrm{M}\Omega/\mathrm{m})$",
+            r"$\frac{1}{\beta_0} \sum_{j} \beta_{x,j} Z_{x,j}^\mathrm{Quad} \; (\mathrm{M}\Omega/\mathrm{m})$",
+            r"$\frac{1}{\beta_0} \sum_{j} \beta_{y,j} Z_{y,j}^\mathrm{Quad} \; (\mathrm{M}\Omega/\mathrm{m})$"
+        ]
         leg = Ztype_dict[Z_type]
 
         # sort plot by decresing area size
-        area = np.zeros((len(attr_list),))
+        area = np.zeros((len(attr_list), ))
         for index, attr in enumerate(attr_list):
             try:
                 sum_imp = getattr(getattr(self, attr), Z_type)
-                area[index] = trapz(
-                    sum_imp.data[component], sum_imp.data.index)
+                area[index] = trapz(sum_imp.data[component],
+                                    sum_imp.data.index)
             except AttributeError:
                 pass
         sorted_index = area.argsort()[::-1]
@@ -330,7 +349,7 @@ class ImpedanceModel():
         # Init fig
         fig = plt.figure()
         ax = fig.gca()
-        zero_impedance = getattr(self.sum, Z_type)*0
+        zero_impedance = getattr(self.sum, Z_type) * 0
         total_imp = 0
         legend = []
 
@@ -344,9 +363,13 @@ class ImpedanceModel():
             # Set all impedances with common indexes using + zero_impedance
             try:
                 sum_imp = getattr(getattr(self, attr), Z_type) + zero_impedance
-                ax.fill_between(sum_imp.data.index*1e-9, total_imp,
-                                total_imp + sum_imp.data[component]*scale[leg], edgecolor=colorblind[index % 10], color=colorblind[index % 10])
-                total_imp += sum_imp.data[component]*scale[leg]
+                ax.fill_between(sum_imp.data.index * 1e-9,
+                                total_imp,
+                                total_imp +
+                                sum_imp.data[component] * scale[leg],
+                                edgecolor=colorblind[index % 10],
+                                color=colorblind[index % 10])
+                total_imp += sum_imp.data[component] * scale[leg]
                 if attr[:4] == "sum_":
                     legend.append(attr[4:])
                 else:
@@ -356,8 +379,8 @@ class ImpedanceModel():
 
         if sigma is not None:
             spect = spectral_density(zero_impedance.data.index, sigma)
-            spect = spect/spect.max()*total_imp.max()
-            ax.plot(sum_imp.data.index*1e-9, spect, 'r', linewidth=2.5)
+            spect = spect / spect.max() * total_imp.max()
+            ax.plot(sum_imp.data.index * 1e-9, spect, 'r', linewidth=2.5)
 
         ax.legend(legend, loc="upper left", ncol=2)
         ax.set_xlabel("Frequency (GHz)")
@@ -365,10 +388,11 @@ class ImpedanceModel():
         ax.set_title(label_list[leg] + " - " + component + " part")
 
         if zoom is True:
-            in_ax = inset_axes(ax,
-                               width="30%",  # width = 30% of parent_bbox
-                               height=1.5,  # height : 1 inch
-                               loc=1)
+            in_ax = inset_axes(
+                ax,
+                width="30%",  # width = 30% of parent_bbox
+                height=1.5,  # height : 1 inch
+                loc=1)
 
             total_imp = 0
             for index in sorted_index:
@@ -377,9 +401,13 @@ class ImpedanceModel():
                 try:
                     sum_imp = getattr(getattr(self, attr),
                                       Z_type) + zero_impedance
-                    in_ax.fill_between(sum_imp.data.index*1e-3, total_imp,
-                                       total_imp + sum_imp.data[component]*1e-9, edgecolor=colorblind[index % 10], color=colorblind[index % 10])
-                    total_imp += sum_imp.data[component]*1e-9
+                    in_ax.fill_between(sum_imp.data.index * 1e-3,
+                                       total_imp,
+                                       total_imp +
+                                       sum_imp.data[component] * 1e-9,
+                                       edgecolor=colorblind[index % 10],
+                                       color=colorblind[index % 10])
+                    total_imp += sum_imp.data[component] * 1e-9
                 except AttributeError:
                     pass
             in_ax.set_xlim([0, 200])
@@ -388,7 +416,13 @@ class ImpedanceModel():
 
         return fig
 
-    def effective_impedance(self, m, mu, sigma, M, tuneS, xi=None,
+    def effective_impedance(self,
+                            m,
+                            mu,
+                            sigma,
+                            M,
+                            tuneS,
+                            xi=None,
                             mode="Hermite"):
         """
         Compute the longitudinal and transverse effective impedance.
@@ -426,36 +460,38 @@ class ImpedanceModel():
         for i, attr in enumerate(attr_list):
             try:
                 impedance = getattr(getattr(self, attr), "Zlong")
-                eff_array[i, 0] = effective_impedance(self.ring, impedance,
-                                                      m, mu, sigma, M, tuneS,
-                                                      xi, mode)
+                eff_array[i,
+                          0] = effective_impedance(self.ring, impedance, m, mu,
+                                                   sigma, M, tuneS, xi, mode)
             except AttributeError:
                 pass
 
             try:
                 impedance = getattr(getattr(self, attr), "Zxdip")
-                eff_array[i, 1] = effective_impedance(self.ring, impedance,
-                                                      m, mu, sigma, M, tuneS,
-                                                      xi, mode)
+                eff_array[i,
+                          1] = effective_impedance(self.ring, impedance, m, mu,
+                                                   sigma, M, tuneS, xi, mode)
             except AttributeError:
                 pass
 
             try:
                 impedance = getattr(getattr(self, attr), "Zydip")
-                eff_array[i, 2] = effective_impedance(self.ring, impedance,
-                                                      m, mu, sigma, M, tuneS,
-                                                      xi, mode)
+                eff_array[i,
+                          2] = effective_impedance(self.ring, impedance, m, mu,
+                                                   sigma, M, tuneS, xi, mode)
             except AttributeError:
                 pass
 
-        eff_array[:, 0] = eff_array[:, 0]*self.ring.omega0*1e3
-        eff_array[:, 1] = eff_array[:, 1]*1e-3
-        eff_array[:, 2] = eff_array[:, 2]*1e-3
+        eff_array[:, 0] = eff_array[:, 0] * self.ring.omega0 * 1e3
+        eff_array[:, 1] = eff_array[:, 1] * 1e-3
+        eff_array[:, 2] = eff_array[:, 2] * 1e-3
 
-        summary = pd.DataFrame(eff_array, index=attr_list,
-                               columns=["Z/n [mOhm]",
-                                        "sum betax x Zxeff [kOhm]",
-                                        "sum betay x Zyeff [kOhm]"])
+        summary = pd.DataFrame(eff_array,
+                               index=attr_list,
+                               columns=[
+                                   "Z/n [mOhm]", "sum betax x Zxeff [kOhm]",
+                                   "sum betay x Zyeff [kOhm]"
+                               ])
 
         return summary
 
@@ -489,10 +525,10 @@ class ImpedanceModel():
         fmax = self.sum.Zlong.data.index.max()
         fmin = self.sum.Zlong.data.index.min()
 
-        Q = I*self.ring.T0/M
+        Q = I * self.ring.T0 / M
 
         if fmin >= 0:
-            fmin = -1*fmax
+            fmin = -1 * fmax
         f = np.linspace(fmin, fmax, int(n_points))
 
         beam_spect = beam_spectrum(f, M, bunch_spacing, sigma=sigma)
@@ -506,16 +542,20 @@ class ImpedanceModel():
         for i, attr in enumerate(attr_list):
             try:
                 impedance = getattr(getattr(self, attr), "Zlong")
-                loss_array[i, 0] = beam_loss_factor(
-                    impedance, f, beam_spect, self.ring)
-                loss_array[i, 1] = beam_loss_factor(
-                    impedance, f, bunch_spect, self.ring)
+                loss_array[i, 0] = beam_loss_factor(impedance, f, beam_spect,
+                                                    self.ring)
+                loss_array[i, 1] = beam_loss_factor(impedance, f, bunch_spect,
+                                                    self.ring)
             except AttributeError:
                 pass
 
-        loss_array = loss_array*1e-12
-        summary = pd.DataFrame(loss_array, index=attr_list,
-                               columns=["loss factor (beam) [V/pC]", "loss factor (bunch) [V/pC]"])
+        loss_array = loss_array * 1e-12
+        summary = pd.DataFrame(loss_array,
+                               index=attr_list,
+                               columns=[
+                                   "loss factor (beam) [V/pC]",
+                                   "loss factor (bunch) [V/pC]"
+                               ])
 
         summary["P (beam) [W]"] = summary["loss factor (beam) [V/pC]"] * \
             1e12*Q**2/(self.ring.T0)
@@ -524,8 +564,14 @@ class ImpedanceModel():
 
         return summary
 
-    def power_loss_spectrum(self, sigma, M, bunch_spacing, I, n_points=10e6,
-                            max_overlap=False, plot=False):
+    def power_loss_spectrum(self,
+                            sigma,
+                            M,
+                            bunch_spacing,
+                            I,
+                            n_points=10e6,
+                            max_overlap=False,
+                            plot=False):
         """
         Compute the power loss spectrum of the summed longitudinal impedance 
         as in Eq. (4) of [1].
@@ -575,23 +621,23 @@ class ImpedanceModel():
         fmax = impedance.data.index.max()
         fmin = impedance.data.index.min()
 
-        Q = I*self.ring.T0/M
+        Q = I * self.ring.T0 / M
 
         if fmin >= 0:
-            fmin = -1*fmax
+            fmin = -1 * fmax
             double_sided_impedance(impedance)
 
         frequency = np.linspace(fmin, fmax, int(n_points))
         if max_overlap is False:
             spectrum = beam_spectrum(frequency, M, bunch_spacing, sigma)
         else:
-            spectrum = gaussian_bunch_spectrum(frequency, sigma)*M
+            spectrum = gaussian_bunch_spectrum(frequency, sigma) * M
 
-        pmax = np.floor(fmax/self.ring.f0)
-        pmin = np.floor(fmin/self.ring.f0)
+        pmax = np.floor(fmax / self.ring.f0)
+        pmin = np.floor(fmin / self.ring.f0)
 
-        p = np.arange(pmin+1, pmax)
-        pf0 = p*self.ring.f0
+        p = np.arange(pmin + 1, pmax)
+        pf0 = p * self.ring.f0
         ReZ = np.real(impedance(pf0))
         spectral_density = np.abs(spectrum)**2
         # interpolation of the spectrum is needed to avoid problems liked to
@@ -605,7 +651,9 @@ class ImpedanceModel():
             fig, ax = plt.subplots()
             twin = ax.twinx()
             p1, = ax.plot(pf0, ReZ, color="r", label="Re[Z] [Ohm]")
-            p2, = twin.plot(pf0, spect(pf0)*(self.ring.f0 * Q)**2, color="b",
+            p2, = twin.plot(pf0,
+                            spect(pf0) * (self.ring.f0 * Q)**2,
+                            color="b",
                             label="Beam spectrum [a.u.]")
             ax.set_xlabel("Frequency [Hz]")
             ax.set_ylabel("Re[Z] [Ohm]")
@@ -637,11 +685,13 @@ class ImpedanceModel():
         None.
 
         """
-        to_save = {"wakefields": self.wakefields,
-                   "positions": self.positions,
-                   "names": self.names,
-                   "globals": self.globals,
-                   "globals_names": self.globals_names}
+        to_save = {
+            "wakefields": self.wakefields,
+            "positions": self.positions,
+            "names": self.names,
+            "globals": self.globals,
+            "globals_names": self.globals_names
+        }
         with open(file, "wb") as f:
             pickle.dump(to_save, f)
 
diff --git a/mbtrack2/impedance/resistive_wall.py b/mbtrack2/impedance/resistive_wall.py
index 13a736b499fbd94e967bbfce75584b6001d1d8c9..1676ba02bc2354ff925c505c48ab8073296a5591 100644
--- a/mbtrack2/impedance/resistive_wall.py
+++ b/mbtrack2/impedance/resistive_wall.py
@@ -4,11 +4,13 @@ Define resistive wall elements based on the WakeField class.
 """
 
 import numpy as np
-from scipy.constants import mu_0, epsilon_0, c
+from scipy.constants import c, epsilon_0, mu_0
 from scipy.integrate import quad
-from mbtrack2.impedance.wakefield import WakeField, Impedance, WakeFunction
 
-def skin_depth(frequency, rho, mu_r = 1, epsilon_r = 1):
+from mbtrack2.impedance.wakefield import Impedance, WakeField, WakeFunction
+
+
+def skin_depth(frequency, rho, mu_r=1, epsilon_r=1):
     """
     General formula for the skin depth.
     
@@ -29,13 +31,14 @@ def skin_depth(frequency, rho, mu_r = 1, epsilon_r = 1):
         Skin depth in [m].
     
     """
-    
-    delta = (np.sqrt(2*rho/(np.abs(2*np.pi*frequency)*mu_r*mu_0)) * 
-             np.sqrt(np.sqrt(1 + (rho*np.abs(2*np.pi*frequency) * 
-                                  epsilon_r*epsilon_0)**2 ) 
-                    + rho*np.abs(2*np.pi*frequency)*epsilon_r*epsilon_0))
+
+    delta = (np.sqrt(2 * rho / (np.abs(2 * np.pi * frequency) * mu_r * mu_0)) *
+             np.sqrt(
+                 np.sqrt(1 + (rho * np.abs(2 * np.pi * frequency) * epsilon_r *
+                              epsilon_0)**2) +
+                 rho * np.abs(2 * np.pi * frequency) * epsilon_r * epsilon_0))
     return delta
-    
+
 
 class CircularResistiveWall(WakeField):
     """
@@ -74,40 +77,51 @@ class CircularResistiveWall(WakeField):
     Detectors and Associated Equipment 806 (2016): 221-230.
 
     """
-
-    def __init__(self, time, frequency, length, rho, radius, exact=False, 
+    def __init__(self,
+                 time,
+                 frequency,
+                 length,
+                 rho,
+                 radius,
+                 exact=False,
                  atol=1e-20):
         super().__init__()
-        
+
         self.length = length
         self.rho = rho
         self.radius = radius
-        self.Z0 = mu_0*c
-        self.t0 = (2*self.rho*self.radius**2 / self.Z0)**(1/3) / c
-        
-        omega = 2*np.pi*frequency
-        Z1 = length*(1 + np.sign(frequency)*1j)*rho/(
-                2*np.pi*radius*skin_depth(frequency,rho))
-        Z2 = c/omega*length*(1 + np.sign(frequency)*1j)*rho/(
-                np.pi*radius**3*skin_depth(frequency,rho))
-        
+        self.Z0 = mu_0 * c
+        self.t0 = (2 * self.rho * self.radius**2 / self.Z0)**(1 / 3) / c
+
+        omega = 2 * np.pi * frequency
+        Z1 = length * (1 + np.sign(frequency) * 1j) * rho / (
+            2 * np.pi * radius * skin_depth(frequency, rho))
+        Z2 = c / omega * length * (1 + np.sign(frequency) * 1j) * rho / (
+            np.pi * radius**3 * skin_depth(frequency, rho))
+
         Wl = self.LongitudinalWakeFunction(time, exact, atol)
         Wt = self.TransverseWakeFunction(time, exact)
-        
-        Zlong = Impedance(variable = frequency, function = Z1, component_type='long')
-        Zxdip = Impedance(variable = frequency, function = Z2, component_type='xdip')
-        Zydip = Impedance(variable = frequency, function = Z2, component_type='ydip')
-        Wlong = WakeFunction(variable = time, function = Wl, component_type="long")
-        Wxdip = WakeFunction(variable = time, function = Wt, component_type="xdip")
-        Wydip = WakeFunction(variable = time, function = Wt, component_type="ydip")
-        
+
+        Zlong = Impedance(variable=frequency,
+                          function=Z1,
+                          component_type='long')
+        Zxdip = Impedance(variable=frequency,
+                          function=Z2,
+                          component_type='xdip')
+        Zydip = Impedance(variable=frequency,
+                          function=Z2,
+                          component_type='ydip')
+        Wlong = WakeFunction(variable=time, function=Wl, component_type="long")
+        Wxdip = WakeFunction(variable=time, function=Wt, component_type="xdip")
+        Wydip = WakeFunction(variable=time, function=Wt, component_type="ydip")
+
         super().append_to_model(Zlong)
         super().append_to_model(Zxdip)
         super().append_to_model(Zydip)
         super().append_to_model(Wlong)
         super().append_to_model(Wxdip)
         super().append_to_model(Wydip)
-        
+
     def LongitudinalWakeFunction(self, time, exact=False, atol=1e-20):
         """
         Compute the longitudinal wake function of a circular resistive wall 
@@ -144,15 +158,15 @@ class CircularResistiveWall(WakeField):
         wl = np.zeros_like(time)
         idx1 = time < 0
         wl[idx1] = 0
-        if exact==True:
+        if exact == True:
             idx2 = time > 20 * self.t0
-            idx3 = np.logical_not(np.logical_or(idx1,idx2))
+            idx3 = np.logical_not(np.logical_or(idx1, idx2))
             wl[idx3] = self.__LongWakeExact(time[idx3], atol)
         else:
             idx2 = np.logical_not(idx1)
         wl[idx2] = self.__LongWakeApprox(time[idx2])
         return wl
-    
+
     def TransverseWakeFunction(self, time, exact=False):
         """
         Compute the transverse wake function of a circular resistive wall 
@@ -184,61 +198,68 @@ class CircularResistiveWall(WakeField):
         wt = np.zeros_like(time)
         idx1 = time < 0
         wt[idx1] = 0
-        if exact==True:
+        if exact == True:
             idx2 = time > 20 * self.t0
-            idx3 = np.logical_not(np.logical_or(idx1,idx2))
+            idx3 = np.logical_not(np.logical_or(idx1, idx2))
             wt[idx3] = self.__TransWakeExact(time[idx3])
         else:
             idx2 = np.logical_not(idx1)
         wt[idx2] = self.__TransWakeApprox(time[idx2])
         return wt
-    
+
     def __LongWakeExact(self, time, atol):
         wl = np.zeros_like(time)
-        factor = 4*self.Z0*c/(np.pi * self.radius**2) * self.length
+        factor = 4 * self.Z0 * c / (np.pi * self.radius**2) * self.length
         for i, t in enumerate(time):
-            val, err = quad(lambda z:self.__function(t, z), 0, np.inf)
-            wl[i] = factor * ( np.exp(-t/self.t0) / 3 * 
-                              np.cos( np.sqrt(3) * t / self.t0 )  
-                              - np.sqrt(2) / np.pi * val )
+            val, err = quad(lambda z: self.__function(t, z), 0, np.inf)
+            wl[i] = factor * (
+                np.exp(-t / self.t0) / 3 * np.cos(np.sqrt(3) * t / self.t0) -
+                np.sqrt(2) / np.pi * val)
             if np.isclose(0, t, atol=atol):
-                wl[i] = wl[i]/2
+                wl[i] = wl[i] / 2
         return wl
-    
+
     def __TransWakeExact(self, time):
         wt = np.zeros_like(time)
-        factor = ((8 * self.Z0 * c**2 * self.t0) / (np.pi * self.radius**4) * 
+        factor = ((8 * self.Z0 * c**2 * self.t0) / (np.pi * self.radius**4) *
                   self.length)
         for i, t in enumerate(time):
-            val, err = quad(lambda z:self.__function2(t, z), 0, np.inf)
-            wt[i] = factor * ( 1 / 12 * (-1 * np.exp(-t/self.t0) * 
-                                      np.cos( np.sqrt(3) * t / self.t0 ) + 
-                                      np.sqrt(3) * np.exp(-t/self.t0) * 
-                                      np.sin( np.sqrt(3) * t / self.t0 ) ) -
-                                      np.sqrt(2) / np.pi * val )
+            val, err = quad(lambda z: self.__function2(t, z), 0, np.inf)
+            wt[i] = factor * (
+                1 / 12 *
+                (-1 * np.exp(-t / self.t0) * np.cos(np.sqrt(3) * t / self.t0) +
+                 np.sqrt(3) * np.exp(-t / self.t0) *
+                 np.sin(np.sqrt(3) * t / self.t0)) - np.sqrt(2) / np.pi * val)
         return wt
-    
+
     def __LongWakeApprox(self, t):
-        wl = - 1 * ( 1 / (4*np.pi * self.radius) * 
-                    np.sqrt(self.Z0 * self.rho / (c * np.pi) ) /
-                    t ** (3/2) ) * self.length
+        wl = -1 * (1 / (4 * np.pi * self.radius) * np.sqrt(self.Z0 * self.rho /
+                                                           (c * np.pi)) /
+                   t**(3 / 2)) * self.length
         return wl
-    
+
     def __TransWakeApprox(self, t):
         wt = (1 / (np.pi * self.radius**3) *
-              np.sqrt(self.Z0 * c * self.rho / np.pi)
-              / t ** (1/2) * self.length)
+              np.sqrt(self.Z0 * c * self.rho / np.pi) / t**(1 / 2) *
+              self.length)
         return wt
-    
+
     def __function(self, t, x):
-        return ( (x**2 * np.exp(-1* (x**2) * t / self.t0) ) / (x**6 + 8) )
-    
+        return ((x**2 * np.exp(-1 * (x**2) * t / self.t0)) / (x**6 + 8))
+
     def __function2(self, t, x):
-        return ( (-1 * np.exp(-1* (x**2) * t / self.t0) ) / (x**6 + 8) )
-    
+        return ((-1 * np.exp(-1 * (x**2) * t / self.t0)) / (x**6 + 8))
+
+
 class Coating(WakeField):
-    
-    def __init__(self, frequency, length, rho1, rho2, radius, thickness, approx=False):
+    def __init__(self,
+                 frequency,
+                 length,
+                 rho1,
+                 rho2,
+                 radius,
+                 thickness,
+                 approx=False):
         """
         WakeField element for a coated circular beam pipe.
         
@@ -270,24 +291,30 @@ class Coating(WakeField):
 
         """
         super().__init__()
-        
+
         self.length = length
         self.rho1 = rho1
         self.rho2 = rho2
         self.radius = radius
         self.thickness = thickness
-        
+
         Zl = self.LongitudinalImpedance(frequency, approx)
         Zt = self.TransverseImpedance(frequency, approx)
-        
-        Zlong = Impedance(variable = frequency, function = Zl, component_type='long')
-        Zxdip = Impedance(variable = frequency, function = Zt, component_type='xdip')
-        Zydip = Impedance(variable = frequency, function = Zt, component_type='ydip')
-        
+
+        Zlong = Impedance(variable=frequency,
+                          function=Zl,
+                          component_type='long')
+        Zxdip = Impedance(variable=frequency,
+                          function=Zt,
+                          component_type='xdip')
+        Zydip = Impedance(variable=frequency,
+                          function=Zt,
+                          component_type='ydip')
+
         super().append_to_model(Zlong)
         super().append_to_model(Zxdip)
         super().append_to_model(Zydip)
-        
+
     def LongitudinalImpedance(self, f, approx):
         """
         Compute the longitudinal impedance of a coating using Eq. (5), or 
@@ -314,28 +341,29 @@ class Coating(WakeField):
         Physical Review Accelerators and Beams 21.4 (2018): 041001.
 
         """
-        
-        Z0 = mu_0*c
-        factor = Z0*f/(2*c*self.radius)*self.length
+
+        Z0 = mu_0 * c
+        factor = Z0 * f / (2 * c * self.radius) * self.length
         skin1 = skin_depth(f, self.rho1)
         skin2 = skin_depth(f, self.rho2)
-        
+
         if approx == False:
-            alpha = skin1/skin2
-            tanh = np.tanh( (1 + 1j*np.sign(f)) * self.thickness / skin1 )
-            bracket = ( (np.sign(f) + 1j) * skin1 * 
-                       (alpha * tanh + 1) / (alpha + tanh) )
+            alpha = skin1 / skin2
+            tanh = np.tanh((1 + 1j * np.sign(f)) * self.thickness / skin1)
+            bracket = ((np.sign(f) + 1j) * skin1 * (alpha*tanh + 1) /
+                       (alpha+tanh))
         else:
             valid_approx = self.thickness / np.min(skin1)
             if valid_approx < 0.01:
-                print("Approximation is not valid. Returning impedance anyway.")
-            bracket = ( (np.sign(f) + 1j) * skin2 + 2 * 1j * self.thickness * 
-                       (1 - self.rho2/self.rho1) )
-        
+                print(
+                    "Approximation is not valid. Returning impedance anyway.")
+            bracket = ((np.sign(f) + 1j) * skin2 + 2 * 1j * self.thickness *
+                       (1 - self.rho2 / self.rho1))
+
         Zl = factor * bracket
-        
+
         return Zl
-        
+
     def TransverseImpedance(self, f, approx):
         """
         Compute the transverse impedance of a coating using Eq. (6), or 
@@ -362,31 +390,26 @@ class Coating(WakeField):
         Physical Review Accelerators and Beams 21.4 (2018): 041001.
 
         """
-        
-        Z0 = mu_0*c
-        factor = Z0/(2*np.pi*self.radius**3)*self.length
+
+        Z0 = mu_0 * c
+        factor = Z0 / (2 * np.pi * self.radius**3) * self.length
         skin1 = skin_depth(f, self.rho1)
         skin2 = skin_depth(f, self.rho2)
-        
+
         if approx == False:
-            alpha = skin1/skin2
-            tanh = np.tanh( (1 + 1j*np.sign(f)) * self.thickness / skin1 )
-            bracket = ( (1 + 1j*np.sign(f)) * skin1 * 
-                       (alpha * tanh + 1) / (alpha + tanh) )
+            alpha = skin1 / skin2
+            tanh = np.tanh((1 + 1j * np.sign(f)) * self.thickness / skin1)
+            bracket = ((1 + 1j * np.sign(f)) * skin1 * (alpha*tanh + 1) /
+                       (alpha+tanh))
         else:
             valid_approx = self.thickness / np.min(skin1)
             if valid_approx < 0.01:
-                print("Approximation is not valid. Returning impedance anyway.")
-            bracket = ( (1 + 1j*np.sign(f)) * skin2 + 2 * 1j * self.thickness 
-                       * np.sign(f) * (1 - self.rho2/self.rho1) )
-        
+                print(
+                    "Approximation is not valid. Returning impedance anyway.")
+            bracket = ((1 + 1j * np.sign(f)) * skin2 +
+                       2 * 1j * self.thickness * np.sign(f) *
+                       (1 - self.rho2 / self.rho1))
+
         Zt = factor * bracket
-        
+
         return Zt
-        
-        
-        
-        
-        
-        
-        
\ No newline at end of file
diff --git a/mbtrack2/impedance/resonator.py b/mbtrack2/impedance/resonator.py
index 503d106cc40ba0d35f9de87bcd7a1ff22747302f..640a52d135434a55ff70f703e72405e2c81a9873 100644
--- a/mbtrack2/impedance/resonator.py
+++ b/mbtrack2/impedance/resonator.py
@@ -5,8 +5,9 @@ based on the WakeField class.
 """
 
 import numpy as np
-from mbtrack2.impedance.wakefield import (WakeField, Impedance, 
-                                                   WakeFunction)
+
+from mbtrack2.impedance.wakefield import Impedance, WakeField, WakeFunction
+
 
 class Resonator(WakeField):
     def __init__(self, time, frequency, Rs, fr, Q, plane, atol=1e-20):
@@ -48,68 +49,71 @@ class Resonator(WakeField):
             self.plane = [plane]
         elif isinstance(plane, list):
             self.plane = plane
-            
+
         if self.Q >= 0.5:
             self.Q_p = np.sqrt(self.Q**2 - 0.25)
         else:
             self.Q_p = np.sqrt(0.25 - self.Q**2)
-        self.wr_p = (self.wr*self.Q_p)/self.Q
-        
+        self.wr_p = (self.wr * self.Q_p) / self.Q
+
         for dim in self.plane:
             if dim == "long":
-                Zlong = Impedance(variable=frequency, 
-                                function=self.long_impedance(frequency),
-                                component_type="long")
+                Zlong = Impedance(variable=frequency,
+                                  function=self.long_impedance(frequency),
+                                  component_type="long")
                 super().append_to_model(Zlong)
                 Wlong = WakeFunction(variable=time,
-                                    function=self.long_wake_function(time, atol),
-                                    component_type="long")
+                                     function=self.long_wake_function(
+                                         time, atol),
+                                     component_type="long")
                 super().append_to_model(Wlong)
-                
+
             elif dim == "x" or dim == "y":
-                Zdip = Impedance(variable=frequency, 
-                                function=self.transverse_impedance(frequency),
-                                component_type=dim + "dip")
+                Zdip = Impedance(variable=frequency,
+                                 function=self.transverse_impedance(frequency),
+                                 component_type=dim + "dip")
                 super().append_to_model(Zdip)
-                Wdip = WakeFunction(variable=time,
-                                    function=self.transverse_wake_function(time),
-                                    component_type=dim + "dip")
+                Wdip = WakeFunction(
+                    variable=time,
+                    function=self.transverse_wake_function(time),
+                    component_type=dim + "dip")
                 super().append_to_model(Wdip)
             else:
                 raise ValueError("Plane must be: long, x or y")
-        
+
     def long_wake_function(self, t, atol):
         if self.Q >= 0.5:
-            wl = ( (self.wr * self.Rs / self.Q) * 
-                    np.exp(-1* self.wr * t / (2 * self.Q) ) *
-                     (np.cos(self.wr_p * t) - 
-                      np.sin(self.wr_p * t) / (2 * self.Q_p) ) )
+            wl = ((self.wr * self.Rs / self.Q) * np.exp(-1 * self.wr * t /
+                                                        (2 * self.Q)) *
+                  (np.cos(self.wr_p * t) - np.sin(self.wr_p * t) /
+                   (2 * self.Q_p)))
         elif self.Q < 0.5:
-            wl = ( (self.wr * self.Rs / self.Q) * 
-                    np.exp(-1* self.wr * t / (2 * self.Q) ) *
-                     (np.cosh(self.wr_p * t) - 
-                      np.sinh(self.wr_p * t) / (2 * self.Q_p) ) )
+            wl = ((self.wr * self.Rs / self.Q) * np.exp(-1 * self.wr * t /
+                                                        (2 * self.Q)) *
+                  (np.cosh(self.wr_p * t) - np.sinh(self.wr_p * t) /
+                   (2 * self.Q_p)))
         if np.any(np.abs(t) < atol):
-            wl[np.abs(t) < atol] = wl[np.abs(t) < atol]/2
+            wl[np.abs(t) < atol] = wl[np.abs(t) < atol] / 2
         return wl
-                            
+
     def long_impedance(self, f):
-        return self.Rs / (1 + 1j * self.Q * (f/self.fr - self.fr/f))
-    
+        return self.Rs / (1 + 1j * self.Q * (f / self.fr - self.fr / f))
+
     def transverse_impedance(self, f):
-        return self.Rs * self.fr / f / (
-            1 + 1j * self.Q * (f / self.fr - self.fr / f) )
-    
+        return self.Rs * self.fr / f / (1 + 1j * self.Q *
+                                        (f / self.fr - self.fr / f))
+
     def transverse_wake_function(self, t):
         if self.Q >= 0.5:
-            return (self.wr * self.Rs / self.Q_p * 
+            return (self.wr * self.Rs / self.Q_p *
                     np.exp(-1 * t * self.wr / 2 / self.Q_p) *
-                    np.sin(self.wr_p * t) )
+                    np.sin(self.wr_p * t))
         else:
-            return (self.wr * self.Rs / self.Q_p * 
+            return (self.wr * self.Rs / self.Q_p *
                     np.exp(-1 * t * self.wr / 2 / self.Q_p) *
-                    np.sinh(self.wr_p * t) )
-    
+                    np.sinh(self.wr_p * t))
+
+
 class PureInductive(WakeField):
     """
     Pure inductive Wakefield element which computes associated longitudinal 
@@ -127,25 +131,31 @@ class PureInductive(WakeField):
         Maximum frequency used in the impedance. 
     nout, trim : see Impedance.to_wakefunction
     """
-    def __init__(self, L, n_wake=1e6, n_imp=1e6, imp_freq_lim=1e11, nout=None,
+    def __init__(self,
+                 L,
+                 n_wake=1e6,
+                 n_imp=1e6,
+                 imp_freq_lim=1e11,
+                 nout=None,
                  trim=False):
         self.L = L
         self.n_wake = int(n_wake)
         self.n_imp = int(n_imp)
         self.imp_freq_lim = imp_freq_lim
-        
+
         freq = np.linspace(start=1, stop=self.imp_freq_lim, num=self.n_imp)
-        imp = Impedance(variable=freq, 
+        imp = Impedance(variable=freq,
                         function=self.long_impedance(freq),
                         component_type="long")
         super().append_to_model(imp)
-        
+
         wf = imp.to_wakefunction(nout=nout, trim=trim)
         super().append_to_model(wf)
-        
+
     def long_impedance(self, f):
-        return 1j*self.L*f
-    
+        return 1j * self.L * f
+
+
 class PureResistive(WakeField):
     """
     Pure resistive Wakefield element which computes associated longitudinal 
@@ -163,21 +173,26 @@ class PureResistive(WakeField):
         Maximum frequency used in the impedance. 
     nout, trim : see Impedance.to_wakefunction
     """
-    def __init__(self, R, n_wake=1e6, n_imp=1e6, imp_freq_lim=1e11, nout=None,
+    def __init__(self,
+                 R,
+                 n_wake=1e6,
+                 n_imp=1e6,
+                 imp_freq_lim=1e11,
+                 nout=None,
                  trim=False):
         self.R = R
         self.n_wake = int(n_wake)
         self.n_imp = int(n_imp)
         self.imp_freq_lim = imp_freq_lim
-        
+
         freq = np.linspace(start=1, stop=self.imp_freq_lim, num=self.n_imp)
-        imp = Impedance(variable=freq, 
+        imp = Impedance(variable=freq,
                         function=self.long_impedance(freq),
                         component_type="long")
         super().append_to_model(imp)
-        
+
         wf = imp.to_wakefunction(nout=nout, trim=trim)
         super().append_to_model(wf)
-        
+
     def long_impedance(self, f):
-        return self.R
\ No newline at end of file
+        return self.R
diff --git a/mbtrack2/impedance/tapers.py b/mbtrack2/impedance/tapers.py
index e9d575b9e1f7cfbec6140a6248692c3ce81edea0..088fa4b01d2d100ab9287492c9469a4b676570be 100644
--- a/mbtrack2/impedance/tapers.py
+++ b/mbtrack2/impedance/tapers.py
@@ -3,10 +3,12 @@
 Module where taper elements are defined.
 """
 
-from scipy.constants import mu_0, c, pi
 import numpy as np
+from scipy.constants import c, mu_0, pi
 from scipy.integrate import trapz
-from mbtrack2.impedance.wakefield import WakeField, Impedance
+
+from mbtrack2.impedance.wakefield import Impedance, WakeField
+
 
 class StupakovRectangularTaper(WakeField):
     """
@@ -23,11 +25,16 @@ class StupakovRectangularTaper(WakeField):
     length: taper length in [m]
     width : full horizontal width of the taper in [m]
     """
-    
-    def __init__(self, frequency, gap_entrance, gap_exit, length, width, 
-                 m_max=100, n_points=int(1e4)):
+    def __init__(self,
+                 frequency,
+                 gap_entrance,
+                 gap_exit,
+                 length,
+                 width,
+                 m_max=100,
+                 n_points=int(1e4)):
         super().__init__()
-        
+
         self.frequency = frequency
         self.gap_entrance = gap_entrance
         self.gap_exit = gap_exit
@@ -36,104 +43,115 @@ class StupakovRectangularTaper(WakeField):
         self.m_max = m_max
         self.n_points = n_points
 
-        Zlong = Impedance(variable = frequency, function = self.long(), component_type='long')
-        Zxdip = Impedance(variable = frequency, function = self.xdip(), component_type='xdip')
-        Zydip = Impedance(variable = frequency, function = self.ydip(), component_type='ydip')
-        Zxquad = Impedance(variable = frequency, function = -1*self.quad(), component_type='xquad')
-        Zyquad = Impedance(variable = frequency, function = self.quad(), component_type='yquad')
-        
+        Zlong = Impedance(variable=frequency,
+                          function=self.long(),
+                          component_type='long')
+        Zxdip = Impedance(variable=frequency,
+                          function=self.xdip(),
+                          component_type='xdip')
+        Zydip = Impedance(variable=frequency,
+                          function=self.ydip(),
+                          component_type='ydip')
+        Zxquad = Impedance(variable=frequency,
+                           function=-1 * self.quad(),
+                           component_type='xquad')
+        Zyquad = Impedance(variable=frequency,
+                           function=self.quad(),
+                           component_type='yquad')
+
         super().append_to_model(Zlong)
         super().append_to_model(Zxdip)
         super().append_to_model(Zydip)
         super().append_to_model(Zxquad)
         super().append_to_model(Zyquad)
-        
+
     @property
     def gap_prime(self):
-        return (self.gap_entrance-self.gap_exit)/self.length
-    
+        return (self.gap_entrance - self.gap_exit) / self.length
+
     @property
     def angle(self):
-        return np.arctan((self.gap_entrance/2 - self.gap_exit/2)/self.length)
-    
+        return np.arctan(
+            (self.gap_entrance / 2 - self.gap_exit / 2) / self.length)
+
     @property
     def Z0(self):
-        return mu_0*c
+        return mu_0 * c
 
     def long(self, frequency=None):
-        
+
         if frequency is None:
             frequency = self.frequency
-        
+
         def F(x, m_max):
             m = np.arange(0, m_max)
-            phi = np.outer(pi*x/2, 2*m+1)
-            val = 1/(2*m+1)/(np.cosh(phi)**2)*np.tanh(phi)
+            phi = np.outer(pi * x / 2, 2*m + 1)
+            val = 1 / (2*m + 1) / (np.cosh(phi)**2) * np.tanh(phi)
             return val.sum(1)
-    
+
         z = np.linspace(0, self.length, self.n_points)
         g = np.linspace(self.gap_entrance, self.gap_exit, self.n_points)
-        
-        to_integrate = self.gap_prime**2*F(g/self.width, self.m_max)
-        integral = trapz(to_integrate,x=z)
-        
-        return -1j*frequency*self.Z0/(2*c)*integral
-    
+
+        to_integrate = self.gap_prime**2 * F(g / self.width, self.m_max)
+        integral = trapz(to_integrate, x=z)
+
+        return -1j * frequency * self.Z0 / (2*c) * integral
+
     def Z_over_n(self, f0):
-        return np.imag(self.long(1))*f0
+        return np.imag(self.long(1)) * f0
 
     def ydip(self):
-        
         def G1(x, m_max):
             m = np.arange(0, m_max)
-            phi = np.outer(pi*x/2, 2*m+1)
-            val = (2*m+1)/(np.sinh(phi)**2)/np.tanh(phi)
-            val = x[:,None]**3*val
+            phi = np.outer(pi * x / 2, 2*m + 1)
+            val = (2*m + 1) / (np.sinh(phi)**2) / np.tanh(phi)
+            val = x[:, None]**3 * val
             return val.sum(1)
-        
+
         z = np.linspace(0, self.length, self.n_points)
         g = np.linspace(self.gap_entrance, self.gap_exit, self.n_points)
-        
-        to_integrate = self.gap_prime**2/(g**3)*G1(g/self.width, self.m_max)
+
+        to_integrate = self.gap_prime**2 / (g**3) * G1(g / self.width,
+                                                       self.m_max)
         integral = trapz(to_integrate, x=z)
-        
-        return -1j*pi*self.width*self.Z0/4*integral
-    
+
+        return -1j * pi * self.width * self.Z0 / 4 * integral
+
     def xdip(self):
-        
         def G3(x, m_max):
-            m = np.arange(0,m_max)
-            phi = np.outer(pi*x, m)
-            val = 2*m/(np.cosh(phi)**2)*np.tanh(phi)
-            val = x[:,None]**2*val
+            m = np.arange(0, m_max)
+            phi = np.outer(pi * x, m)
+            val = 2 * m / (np.cosh(phi)**2) * np.tanh(phi)
+            val = x[:, None]**2 * val
             return val.sum(1)
-        
+
         z = np.linspace(0, self.length, self.n_points)
         g = np.linspace(self.gap_entrance, self.gap_exit, self.n_points)
-        
-        to_integrate = self.gap_prime**2/(g**2)*G3(g/self.width, self.m_max)
+
+        to_integrate = self.gap_prime**2 / (g**2) * G3(g / self.width,
+                                                       self.m_max)
         integral = trapz(to_integrate, x=z)
-        
-        return -1j*pi*self.Z0/4*integral
-    
-    
+
+        return -1j * pi * self.Z0 / 4 * integral
+
     def quad(self):
-        
         def G2(x, m_max):
             m = np.arange(0, m_max)
-            phi = np.outer(pi*x/2, 2*m+1)
-            val = (2*m+1)/(np.cosh(phi)**2)*np.tanh(phi)
-            val = x[:,None]**2*val
+            phi = np.outer(pi * x / 2, 2*m + 1)
+            val = (2*m + 1) / (np.cosh(phi)**2) * np.tanh(phi)
+            val = x[:, None]**2 * val
             return val.sum(1)
-    
+
         z = np.linspace(0, self.length, self.n_points)
         g = np.linspace(self.gap_entrance, self.gap_exit, self.n_points)
-        
-        to_integrate = self.gap_prime**2/(g**2)*G2(g/self.width, self.m_max)
+
+        to_integrate = self.gap_prime**2 / (g**2) * G2(g / self.width,
+                                                       self.m_max)
         integral = trapz(to_integrate, x=z)
-        
-        return -1j*pi*self.Z0/4*integral
-    
+
+        return -1j * pi * self.Z0 / 4 * integral
+
+
 class StupakovCircularTaper(WakeField):
     """
     Circular taper WakeField element, using the low frequency 
@@ -148,11 +166,15 @@ class StupakovCircularTaper(WakeField):
     radius_exit : radius at taper exit in [m]
     length : taper length in [m]
     """
-    
-    def __init__(self, frequency, radius_entrance, radius_exit, length,
-                 m_max=100, n_points=int(1e4)):
+    def __init__(self,
+                 frequency,
+                 radius_entrance,
+                 radius_exit,
+                 length,
+                 m_max=100,
+                 n_points=int(1e4)):
         super().__init__()
-        
+
         self.frequency = frequency
         self.radius_entrance = radius_entrance
         self.radius_exit = radius_exit
@@ -160,51 +182,57 @@ class StupakovCircularTaper(WakeField):
         self.m_max = m_max
         self.n_points = n_points
 
-        Zlong = Impedance(variable = frequency, function = self.long(), component_type='long')
-        Zxdip = Impedance(variable = frequency, function = self.dip(), component_type='xdip')
-        Zydip = Impedance(variable = frequency, function = self.dip(), component_type='ydip')
-        
+        Zlong = Impedance(variable=frequency,
+                          function=self.long(),
+                          component_type='long')
+        Zxdip = Impedance(variable=frequency,
+                          function=self.dip(),
+                          component_type='xdip')
+        Zydip = Impedance(variable=frequency,
+                          function=self.dip(),
+                          component_type='ydip')
+
         super().append_to_model(Zlong)
         super().append_to_model(Zxdip)
         super().append_to_model(Zydip)
-        
+
     @property
     def angle(self):
-        return np.arctan((self.radius_entrance-self.radius_exit)/self.length)
-    
+        return np.arctan(
+            (self.radius_entrance - self.radius_exit) / self.length)
+
     @property
     def radius_prime(self):
-        return (self.radius_entrance-self.radius_exit)/self.length
-    
+        return (self.radius_entrance - self.radius_exit) / self.length
+
     @property
     def Z0(self):
-        return mu_0*c
+        return mu_0 * c
 
     def long(self, frequency=None):
-        
+
         if frequency is None:
             frequency = self.frequency
-        
-        return (self.Z0/(2*pi)*np.log(self.radius_entrance/self.radius_exit) 
-                - 1j*self.Z0*frequency/(2*c)*self.radius_prime**2*self.length)
-    
+
+        return (self.Z0 /
+                (2*pi) * np.log(self.radius_entrance / self.radius_exit) -
+                1j * self.Z0 * frequency /
+                (2*c) * self.radius_prime**2 * self.length)
+
     def Z_over_n(self, f0):
-        return np.imag(self.long(1))*f0
+        return np.imag(self.long(1)) * f0
 
     def dip(self, frequency=None):
-        
+
         if frequency is None:
             frequency = self.frequency
-        
+
         z = np.linspace(0, self.length, self.n_points)
         r = np.linspace(self.radius_entrance, self.radius_exit, self.n_points)
-        
-        to_integrate = self.radius_prime**2/(r**2)
-        integral = trapz(to_integrate, x=z)
-        
-        return (self.Z0*c/(4*pi**2*frequency)*(1/(self.radius_exit**2) - 
-               1/(self.radius_entrance**2))  - 1j*self.Z0/(2*pi)*integral)
-
-
 
+        to_integrate = self.radius_prime**2 / (r**2)
+        integral = trapz(to_integrate, x=z)
 
+        return (self.Z0 * c / (4 * pi**2 * frequency) *
+                (1 / (self.radius_exit**2) - 1 /
+                 (self.radius_entrance**2)) - 1j * self.Z0 / (2*pi) * integral)
diff --git a/mbtrack2/impedance/wakefield.py b/mbtrack2/impedance/wakefield.py
index 0976f0c59983bd2297f8eeb23080cf82d43c4f2c..6df7cfc91a742b5105abab22a016d680636a22fa 100644
--- a/mbtrack2/impedance/wakefield.py
+++ b/mbtrack2/impedance/wakefield.py
@@ -4,15 +4,17 @@ This module defines general classes to describes wakefields, impedances and
 wake functions.
 """
 
-import warnings
 import pickle
-import pandas as pd
+import warnings
+from copy import deepcopy
+
 import numpy as np
+import pandas as pd
 import scipy as sc
-from copy import deepcopy
-from scipy.interpolate import interp1d
-from scipy.integrate import trapz
 from scipy.constants import c
+from scipy.integrate import trapz
+from scipy.interpolate import interp1d
+
 
 class ComplexData:
     """
@@ -27,15 +29,21 @@ class ComplexData:
     function : list or numpy array of comp lex numbers
         contains the values taken by the complex function
     """
-
-    def __init__(self, variable=np.array([-1e15, 1e15]),
+    def __init__(self,
+                 variable=np.array([-1e15, 1e15]),
                  function=np.array([0, 0])):
-        self.data = pd.DataFrame({'real': np.real(function),
-                                  'imag': np.imag(function)},
-                                 index=variable)
+        self.data = pd.DataFrame(
+            {
+                'real': np.real(function),
+                'imag': np.imag(function)
+            },
+            index=variable)
         self.data.index.name = 'variable'
 
-    def add(self, structure_to_add, method='zero', interp_kind='cubic', 
+    def add(self,
+            structure_to_add,
+            method='zero',
+            interp_kind='cubic',
             index_name="variable"):
         """
         Method to add two structures. If the data don't have the same length,
@@ -70,20 +78,23 @@ class ComplexData:
         # from the two input data
 
         if isinstance(structure_to_add, (int, float, complex)):
-            structure_to_add = ComplexData(variable=self.data.index,
-                                           function=(structure_to_add * 
-                                                     np.ones(len(self.data.index))))
-                                
-        data_to_concat = structure_to_add.data.index.to_frame().set_index(index_name)
-        
+            structure_to_add = ComplexData(
+                variable=self.data.index,
+                function=(structure_to_add * np.ones(len(self.data.index))))
+
+        data_to_concat = structure_to_add.data.index.to_frame().set_index(
+            index_name)
+
         initial_data = pd.concat([self.data, data_to_concat], sort=True)
-        initial_data = initial_data[~initial_data.index.duplicated(keep='first')]
+        initial_data = initial_data[~initial_data.index.duplicated(
+            keep='first')]
         initial_data = initial_data.sort_index()
 
-        data_to_add = pd.concat(
-                        [structure_to_add.data,
-                         self.data.index.to_frame().set_index(index_name)],
-                        sort=True)
+        data_to_add = pd.concat([
+            structure_to_add.data,
+            self.data.index.to_frame().set_index(index_name)
+        ],
+                                sort=True)
         data_to_add = data_to_add[~data_to_add.index.duplicated(keep='first')]
         data_to_add = data_to_add.sort_index()
 
@@ -111,7 +122,7 @@ class ComplexData:
         if method == 'extrapolate':
             print('Not there yet')
             return self
-        
+
         if method == 'zero':
             max_variable = min(structure_to_add.data.index.max(),
                                self.data.index.max())
@@ -121,15 +132,17 @@ class ComplexData:
 
             mask = ((initial_data.index <= max_variable)
                     & (initial_data.index >= min_variable))
-            initial_data[mask] = initial_data[mask].interpolate(method=interp_kind)
+            initial_data[mask] = initial_data[mask].interpolate(
+                method=interp_kind)
 
             mask = ((data_to_add.index <= max_variable)
                     & (data_to_add.index >= min_variable))
-            data_to_add[mask] = data_to_add[mask].interpolate(method=interp_kind)
-            
+            data_to_add[mask] = data_to_add[mask].interpolate(
+                method=interp_kind)
+
             initial_data.replace(to_replace=np.nan, value=0, inplace=True)
             data_to_add.replace(to_replace=np.nan, value=0, inplace=True)
-            
+
             result_structure = ComplexData()
             result_structure.data = initial_data + data_to_add
             return result_structure
@@ -159,7 +172,7 @@ class ComplexData:
 
     def __rmul__(self, factor):
         return self.multiply(factor)
-    
+
     def __call__(self, values, interp_kind="cubic"):
         """
         Interpolation of the data by calling the class to have a function-like
@@ -177,35 +190,38 @@ class ComplexData:
         numpy array 
             Contains the interpolated data.
         """
-        real_func = interp1d(x = self.data.index, 
-                             y = self.data["real"], kind=interp_kind)
-        imag_func = interp1d(x = self.data.index, 
-                             y = self.data["imag"], kind=interp_kind)
-        return real_func(values) + 1j*imag_func(values)
-    
+        real_func = interp1d(x=self.data.index,
+                             y=self.data["real"],
+                             kind=interp_kind)
+        imag_func = interp1d(x=self.data.index,
+                             y=self.data["imag"],
+                             kind=interp_kind)
+        return real_func(values) + 1j * imag_func(values)
+
     def __repr__(self):
         """Return representation of data"""
         return f'{(self.data)!r}'
-    
+
     def initialize_coefficient(self):
         """
         Define the impedance coefficients and the plane of the impedance that
         are presents in attributes of the class.
         """
         table = self.name_and_coefficients_table()
-        
+
         try:
             component_coefficients = table[self.component_type].to_dict()
         except KeyError:
-            print('Component type {} does not exist'.format(self.component_type))
+            print('Component type {} does not exist'.format(
+                self.component_type))
             raise
-        
+
         self.a = component_coefficients["a"]
         self.b = component_coefficients["b"]
         self.c = component_coefficients["c"]
         self.d = component_coefficients["d"]
         self.plane = component_coefficients["plane"]
-                    
+
     def name_and_coefficients_table(self):
         """
         Return a table associating the human readbale names of an impedance
@@ -213,44 +229,111 @@ class ComplexData:
         """
 
         component_dict = {
-            'long': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'z'},
-            'xcst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
-            'ycst': {'a': 0, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
-            'xdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'x'},
-            'ydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'y'},
-            'xydip': {'a': 0, 'b': 1, 'c': 0, 'd': 0, 'plane': 'x'},
-            'yxdip': {'a': 1, 'b': 0, 'c': 0, 'd': 0, 'plane': 'y'},
-            'xquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'x'},
-            'yquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'y'},
-            'xyquad': {'a': 0, 'b': 0, 'c': 0, 'd': 1, 'plane': 'x'},
-            'yxquad': {'a': 0, 'b': 0, 'c': 1, 'd': 0, 'plane': 'y'},
-            }
+            'long': {
+                'a': 0,
+                'b': 0,
+                'c': 0,
+                'd': 0,
+                'plane': 'z'
+            },
+            'xcst': {
+                'a': 0,
+                'b': 0,
+                'c': 0,
+                'd': 0,
+                'plane': 'x'
+            },
+            'ycst': {
+                'a': 0,
+                'b': 0,
+                'c': 0,
+                'd': 0,
+                'plane': 'y'
+            },
+            'xdip': {
+                'a': 1,
+                'b': 0,
+                'c': 0,
+                'd': 0,
+                'plane': 'x'
+            },
+            'ydip': {
+                'a': 0,
+                'b': 1,
+                'c': 0,
+                'd': 0,
+                'plane': 'y'
+            },
+            'xydip': {
+                'a': 0,
+                'b': 1,
+                'c': 0,
+                'd': 0,
+                'plane': 'x'
+            },
+            'yxdip': {
+                'a': 1,
+                'b': 0,
+                'c': 0,
+                'd': 0,
+                'plane': 'y'
+            },
+            'xquad': {
+                'a': 0,
+                'b': 0,
+                'c': 1,
+                'd': 0,
+                'plane': 'x'
+            },
+            'yquad': {
+                'a': 0,
+                'b': 0,
+                'c': 0,
+                'd': 1,
+                'plane': 'y'
+            },
+            'xyquad': {
+                'a': 0,
+                'b': 0,
+                'c': 0,
+                'd': 1,
+                'plane': 'x'
+            },
+            'yxquad': {
+                'a': 0,
+                'b': 0,
+                'c': 1,
+                'd': 0,
+                'plane': 'y'
+            },
+        }
 
         return pd.DataFrame(component_dict)
-    
+
     @property
     def power_x(self):
-        power_x = self.a/2 + self.c/2.
+        power_x = self.a / 2 + self.c / 2.
         if self.plane == 'x':
-            power_x += 1./2.
+            power_x += 1. / 2.
         return power_x
 
     @property
     def power_y(self):
-        power_y = self.b/2. + self.d/2.
+        power_y = self.b / 2. + self.d / 2.
         if self.plane == 'y':
-            power_y += 1./2.
+            power_y += 1. / 2.
         return power_y
-    
+
     @property
     def component_type(self):
         return self._component_type
-    
+
     @component_type.setter
     def component_type(self, value):
         self._component_type = value
         self.initialize_coefficient()
-    
+
+
 class WakeFunction(ComplexData):
     """
     Define a WakeFunction object based on a ComplexData object.
@@ -280,31 +363,33 @@ class WakeFunction(ComplexData):
     loss_factor(sigma)
         Compute the loss factor or the kick factor for a Gaussian bunch.
     """
-
     def __init__(self,
                  variable=np.array([-1e15, 1e15]),
-                 function=np.array([0, 0]), component_type='long'):
+                 function=np.array([0, 0]),
+                 component_type='long'):
         super().__init__(variable, function)
         self._component_type = component_type
         self.data.index.name = "time [s]"
         self.initialize_coefficient()
-        
+
     def __repr__(self):
         """Return representation of data"""
         return f'WakeFunction of component_type {self.component_type}:\n {(self.data)!r}'
-        
+
     def add(self, structure_to_add, method='zero'):
         """
         Method to add two WakeFunction objects. The two structures are
         compared so that the addition is self-consistent.
         """
- 
+
         if isinstance(structure_to_add, (int, float, complex)):
-            result = super().add(structure_to_add, method=method,
+            result = super().add(structure_to_add,
+                                 method=method,
                                  index_name="time [s]")
         elif isinstance(structure_to_add, WakeFunction):
             if (self.component_type == structure_to_add.component_type):
-                result = super().add(structure_to_add, method=method,
+                result = super().add(structure_to_add,
+                                     method=method,
                                      index_name="time [s]")
             else:
                 warnings.warn(('The two WakeFunction objects do not have the '
@@ -312,38 +397,36 @@ class WakeFunction(ComplexData):
                                'Returning initial WakeFunction object.'),
                               UserWarning)
                 result = self
-               
-        wake_to_return = WakeFunction(
-                                result.data.index,
-                                result.data.real.values,
-                                self.component_type)   
+
+        wake_to_return = WakeFunction(result.data.index,
+                                      result.data.real.values,
+                                      self.component_type)
         return wake_to_return
- 
+
     def __radd__(self, structure_to_add):
         return self.add(structure_to_add)
- 
+
     def __add__(self, structure_to_add):
         return self.add(structure_to_add)
- 
+
     def multiply(self, factor):
         """
         Multiply a WakeFunction object with a float or an int.
         If the multiplication is done with something else, throw a warning.
         """
         result = super().multiply(factor)
-        wake_to_return = WakeFunction(
-                                result.data.index,
-                                result.data.real.values,
-                                self.component_type)   
+        wake_to_return = WakeFunction(result.data.index,
+                                      result.data.real.values,
+                                      self.component_type)
         return wake_to_return
- 
+
     def __mul__(self, factor):
         return self.multiply(factor)
- 
+
     def __rmul__(self, factor):
         return self.multiply(factor)
-        
-    def to_impedance(self, freq_lim, nout=None, sigma=None, mu=None):   
+
+    def to_impedance(self, freq_lim, nout=None, sigma=None, mu=None):
         """
         Return an Impedance object from the WakeFunction data.
         The WakeFunction data is assumed to be sampled equally.
@@ -372,40 +455,42 @@ class WakeFunction(ComplexData):
         """
         tau = np.array(self.data.index)
         wp = np.array(self.data["real"])
-        
+
         if nout is None:
             nout = len(tau)
         else:
             nout = int(nout)
-        
+
         # FT of wake potential
         sampling = tau[1] - tau[0]
         freq = sc.fft.rfftfreq(nout, sampling)
-        dtau = (tau[-1]-tau[0])/len(tau)
+        dtau = (tau[-1] - tau[0]) / len(tau)
         dft_wake = sc.fft.rfft(wp, n=nout, axis=0) * dtau
-        
+
         i_limit = freq < freq_lim
         freq_trun = freq[i_limit]
         dft_wake_trun = dft_wake[i_limit]
-        
+
         if (sigma is not None) and (mu is not None):
             dft_rho_trun = np.exp(-0.5*(sigma*2*np.pi*freq_trun)**2 + \
                                       1j*mu*2*np.pi*freq_trun)
         else:
             dft_rho_trun = -1
-            
+
         if self.component_type == "long":
-            imp = dft_wake_trun/dft_rho_trun*-1
-        elif (self.component_type == "xdip") or (self.component_type == "ydip"):
-            imp = dft_wake_trun/dft_rho_trun*-1j
+            imp = dft_wake_trun / dft_rho_trun * -1
+        elif (self.component_type == "xdip") or (self.component_type
+                                                 == "ydip"):
+            imp = dft_wake_trun / dft_rho_trun * -1j
         else:
             raise NotImplementedError(self.component_type + " is not correct.")
-            
-        imp = Impedance(variable=freq_trun, function=imp, 
-                         component_type=self.component_type)
-        
+
+        imp = Impedance(variable=freq_trun,
+                        function=imp,
+                        component_type=self.component_type)
+
         return imp
-    
+
     def deconvolution(self, freq_lim, sigma, mu, nout=None):
         """
         Compute a wake function from wake potential data.
@@ -437,7 +522,7 @@ class WakeFunction(ComplexData):
         imp = self.to_impedance(freq_lim, sigma=sigma, mu=mu)
         wf = imp.to_wakefunction(nout=nout)
         return wf
-        
+
     def plot(self):
         """
         Plot the wake function data.
@@ -456,7 +541,7 @@ class WakeFunction(ComplexData):
         label = r"$W_{" + self.component_type + r"}$" + unit
         ax.set_ylabel(label)
         return ax
-    
+
     def loss_factor(self, sigma):
         """
         Compute the loss factor or the kick factor for a Gaussian bunch. 
@@ -480,11 +565,13 @@ class WakeFunction(ComplexData):
         
         """
         time = self.data.index
-        S = 1/(2*np.sqrt(np.pi)*sigma)*np.exp(-1*time**2/(4*sigma**2))
-        kloss = trapz(x = time, y = self.data["real"]*S)
+        S = 1 / (2 * np.sqrt(np.pi) * sigma) * np.exp(-1 * time**2 /
+                                                      (4 * sigma**2))
+        kloss = trapz(x=time, y=self.data["real"] * S)
 
         return kloss
-        
+
+
 class Impedance(ComplexData):
     """
     Define an Impedance object based on a ComplexData object.
@@ -512,15 +599,15 @@ class Impedance(ComplexData):
     plot()
         Plot the impedance data.
     """
-
     def __init__(self,
                  variable=np.array([-1e15, 1e15]),
-                 function=np.array([0, 0]), component_type='long'):
+                 function=np.array([0, 0]),
+                 component_type='long'):
         super().__init__(variable, function)
         self._component_type = component_type
         self.data.index.name = 'frequency [Hz]'
         self.initialize_coefficient()
-        
+
     def __repr__(self):
         """Return representation of data"""
         return f'Impedance of component_type {self.component_type}:\n {(self.data)!r}'
@@ -533,12 +620,14 @@ class Impedance(ComplexData):
         """
 
         if isinstance(structure_to_add, (int, float, complex)):
-            result = super().add(structure_to_add, method=method, 
+            result = super().add(structure_to_add,
+                                 method=method,
                                  index_name="frequency [Hz]")
         elif isinstance(structure_to_add, Impedance):
             if (self.component_type == structure_to_add.component_type):
-                weight = (beta_x ** self.power_x) * (beta_y ** self.power_y)
-                result = super().add(structure_to_add * weight, method=method, 
+                weight = (beta_x**self.power_x) * (beta_y**self.power_y)
+                result = super().add(structure_to_add * weight,
+                                     method=method,
                                      index_name="frequency [Hz]")
             else:
                 warnings.warn(('The two Impedance objects do not have the '
@@ -546,11 +635,11 @@ class Impedance(ComplexData):
                                'Returning initial Impedance object.'),
                               UserWarning)
                 result = self
-                
+
         impedance_to_return = Impedance(
-                                result.data.index,
-                                result.data.real.values + 1j*result.data.imag.values,
-                                self.component_type)    
+            result.data.index,
+            result.data.real.values + 1j * result.data.imag.values,
+            self.component_type)
         return impedance_to_return
 
     def __radd__(self, structure_to_add):
@@ -566,9 +655,9 @@ class Impedance(ComplexData):
         """
         result = super().multiply(factor)
         impedance_to_return = Impedance(
-                                result.data.index,
-                                result.data.real.values + 1j*result.data.imag.values,
-                                self.component_type)    
+            result.data.index,
+            result.data.real.values + 1j * result.data.imag.values,
+            self.component_type)
         return impedance_to_return
 
     def __mul__(self, factor):
@@ -576,7 +665,7 @@ class Impedance(ComplexData):
 
     def __rmul__(self, factor):
         return self.multiply(factor)
-    
+
     def loss_factor(self, sigma):
         """
         Compute the loss factor or the kick factor for a Gaussian bunch. 
@@ -597,28 +686,29 @@ class Impedance(ComplexData):
         ----------
         [1] : Handbook of accelerator physics and engineering, 3rd printing.
         """
-        
+
         positive_index = self.data.index > 0
         frequency = self.data.index[positive_index]
-        
+
         # Import here to avoid circular import
         from mbtrack2.utilities import spectral_density
         sd = spectral_density(frequency, sigma, m=0)
-        
-        if(self.component_type == "long"):
-            kloss = trapz(x = frequency, 
-                          y = 2*self.data["real"][positive_index]*sd)
-        elif(self.component_type == "xdip" or self.component_type == "ydip"):
-            kloss = trapz(x = frequency, 
-                          y = 2*self.data["imag"][positive_index]*sd)
-        elif(self.component_type == "xquad" or self.component_type == "yquad"):
-            kloss = trapz(x = frequency, 
-                          y = 2*self.data["imag"][positive_index]*sd)
+
+        if (self.component_type == "long"):
+            kloss = trapz(x=frequency,
+                          y=2 * self.data["real"][positive_index] * sd)
+        elif (self.component_type == "xdip" or self.component_type == "ydip"):
+            kloss = trapz(x=frequency,
+                          y=2 * self.data["imag"][positive_index] * sd)
+        elif (self.component_type == "xquad"
+              or self.component_type == "yquad"):
+            kloss = trapz(x=frequency,
+                          y=2 * self.data["imag"][positive_index] * sd)
         else:
             raise TypeError("Impedance type not recognized.")
 
         return kloss
-           
+
     def to_wakefunction(self, nout=None, trim=False):
         """
         Return a WakeFunction object from the impedance data.
@@ -638,36 +728,37 @@ class Impedance(ComplexData):
             If a float is given, the pseudo wake function is trimmed from 
             time <= trim to 0. 
         """
-        
-        Z0 = (self.data['real'] + self.data['imag']*1j)
+
+        Z0 = (self.data['real'] + self.data['imag'] * 1j)
         Z = Z0[~np.isnan(Z0)]
-        
+
         if self.component_type != "long":
             Z = Z / 1j
-        
+
         freq = Z.index
-        fs = ( freq[-1] - freq[0] ) / len(freq)
+        fs = (freq[-1] - freq[0]) / len(freq)
         sampling = freq[1] - freq[0]
-        
+
         if nout is None:
             nout = len(Z)
         else:
             nout = int(nout)
-            
+
         time_array = sc.fft.fftfreq(nout, sampling)
         Wlong_raw = sc.fft.irfft(np.array(Z), n=nout, axis=0) * nout * fs
-        
+
         time = sc.fft.fftshift(time_array)
         Wlong = sc.fft.fftshift(Wlong_raw)
-        
+
         if trim is not False:
-            i_neg = np.where(time<trim)[0]
+            i_neg = np.where(time < trim)[0]
             Wlong[i_neg] = 0
-                    
-        wf = WakeFunction(variable=time, function=Wlong, 
+
+        wf = WakeFunction(variable=time,
+                          function=Wlong,
                           component_type=self.component_type)
         return wf
-    
+
     def plot(self):
         """
         Plot the impedance data.
@@ -686,7 +777,8 @@ class Impedance(ComplexData):
         label = r"$Z_{" + self.component_type + r"}$" + unit
         ax.set_ylabel(label)
         return ax
-    
+
+
 class WakeField:
     """
     Defines a WakeField which corresponds to a single physical element which 
@@ -725,26 +817,25 @@ class WakeField:
     load(file)
         Load WakeField element from file.
     """
-
     def __init__(self, structure_list=None, name=None):
         self.list_to_attr(structure_list)
         self.name = name
-        
+
     def __repr__(self):
         """Return representation of WakeField components."""
         return f'WakeField {self.name} with components:\n {(list(self.components))!r}'
-    
+
     def __radd__(self, structure_to_add):
         return self._add(structure_to_add)
 
     def __add__(self, structure_to_add):
         return self._add(structure_to_add)
-    
+
     def __iter__(self):
         """Iterate over components"""
         comp = [getattr(self, comp) for comp in self.components]
         return comp.__iter__()
-    
+
     def _add(self, structure_to_add):
         """Allow to add two WakeField element with different components."""
         for comp in structure_to_add.components:
@@ -775,8 +866,10 @@ class WakeField:
             else:
                 self.__setattr__(attribute_name, structure_to_add)
         else:
-            raise ValueError("{} is not an Impedance nor a WakeFunction.".format(structure_to_add))
-    
+            raise ValueError(
+                "{} is not an Impedance nor a WakeFunction.".format(
+                    structure_to_add))
+
     def list_to_attr(self, structure_list):
         """
         Add list of Impedance/WakeFunction components to WakeField.
@@ -788,32 +881,38 @@ class WakeField:
         if structure_list is not None:
             for component in structure_list:
                 self.append_to_model(component)
-    
+
     @property
     def impedance_components(self):
         """
         Return an array of the impedance component names for the element.
         """
-        valid = ["Zlong", "Zxdip", "Zydip", "Zxquad", "Zyquad", "Zxcst", "Zycst"]
+        valid = [
+            "Zlong", "Zxdip", "Zydip", "Zxquad", "Zyquad", "Zxcst", "Zycst"
+        ]
         return np.array([comp for comp in dir(self) if comp in valid])
-    
+
     @property
     def wake_components(self):
         """
         Return an array of the wake function component names for the element.
         """
-        valid = ["Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad", "Wxcst", "Wycst"]
+        valid = [
+            "Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad", "Wxcst", "Wycst"
+        ]
         return np.array([comp for comp in dir(self) if comp in valid])
-    
+
     @property
     def components(self):
         """
         Return an array of the component names for the element.
         """
-        valid = ["Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad", "Wxcst", "Wycst",
-                 "Zlong", "Zxdip", "Zydip", "Zxquad", "Zyquad", "Zxcst", "Zycst"]
+        valid = [
+            "Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad", "Wxcst", "Wycst",
+            "Zlong", "Zxdip", "Zydip", "Zxquad", "Zyquad", "Zxcst", "Zycst"
+        ]
         return np.array([comp for comp in dir(self) if comp in valid])
-    
+
     def drop(self, component):
         """
         Delete a component or a list of component from the WakeField.
@@ -834,7 +933,7 @@ class WakeField:
             component = self.impedance_components
         elif component == "W":
             component = self.wake_components
-        
+
         if isinstance(component, str):
             delattr(self, component)
         elif isinstance(component, list) or isinstance(component, np.ndarray):
@@ -842,7 +941,7 @@ class WakeField:
                 delattr(self, comp)
         else:
             raise TypeError("component should be a str or list of str.")
-            
+
     def save(self, file):
         """
         Save WakeField element to file.
@@ -860,9 +959,9 @@ class WakeField:
         None.
 
         """
-        with open(file,"wb") as f:
+        with open(file, "wb") as f:
             pickle.dump(self, f)
-            
+
     @staticmethod
     def load(file):
         """
@@ -884,9 +983,9 @@ class WakeField:
         """
         with open(file, 'rb') as f:
             wakefield = pickle.load(f)
-            
+
         return wakefield
-    
+
     @staticmethod
     def add_wakefields(wake1, beta1, wake2, beta2):
         """
@@ -912,23 +1011,20 @@ class WakeField:
         wake_sum = deepcopy(wake1)
         for component_name1 in wake1.components:
             comp1 = getattr(wake_sum, component_name1)
-            weight1 = ((beta1[0] ** comp1.power_x) * 
-                      (beta1[1] ** comp1.power_y))
-            setattr(wake_sum, component_name1, weight1*comp1)
-            
-        for component_name2 in wake2.components: 
+            weight1 = ((beta1[0]**comp1.power_x) * (beta1[1]**comp1.power_y))
+            setattr(wake_sum, component_name1, weight1 * comp1)
+
+        for component_name2 in wake2.components:
             comp2 = getattr(wake2, component_name2)
-            weight2 = ((beta2[0] ** comp2.power_x) * 
-                      (beta2[1] ** comp2.power_y))
+            weight2 = ((beta2[0]**comp2.power_x) * (beta2[1]**comp2.power_y))
             try:
                 comp1 = getattr(wake_sum, component_name2)
-                setattr(wake_sum, component_name2, comp1 +
-                        weight2*comp2)
+                setattr(wake_sum, component_name2, comp1 + weight2*comp2)
             except AttributeError:
-                setattr(wake_sum, component_name2, weight2*comp2)
+                setattr(wake_sum, component_name2, weight2 * comp2)
 
         return wake_sum
-    
+
     @staticmethod
     def add_several_wakefields(wakefields, beta):
         """
@@ -950,12 +1046,12 @@ class WakeField:
         if len(wakefields) == 1:
             return wakefields[0]
         elif len(wakefields) > 1:
-            wake_sum = WakeField.add_wakefields(wakefields[0], beta[:,0],
-                                     wakefields[1], beta[:,1])
+            wake_sum = WakeField.add_wakefields(wakefields[0], beta[:, 0],
+                                                wakefields[1], beta[:, 1])
             for i in range(len(wakefields) - 2):
-                wake_sum = WakeField.add_wakefields(wake_sum, [1 ,1], 
-                                         wakefields[i+2], beta[:,i+2])
+                wake_sum = WakeField.add_wakefields(wake_sum, [1, 1],
+                                                    wakefields[i + 2],
+                                                    beta[:, i + 2])
             return wake_sum
         else:
             raise ValueError("Error in input.")
-        
\ No newline at end of file
diff --git a/mbtrack2/instability/__init__.py b/mbtrack2/instability/__init__.py
index 5a5ee82d8c8fe0de0b26e3bfe6d4042fca46e8f4..f7a1516f07a5ba85219041527106425accf715ce 100644
--- a/mbtrack2/instability/__init__.py
+++ b/mbtrack2/instability/__init__.py
@@ -1,12 +1,16 @@
 # -*- coding: utf-8 -*-
-from mbtrack2.instability.ions import (ion_cross_section,
-                                       ion_frequency,
-                                       fast_beam_ion,
-                                       plot_critical_mass)
-from mbtrack2.instability.instabilities import (mbi_threshold,
-                                                cbi_threshold,
-                                                lcbi_growth_rate_mode,
-                                                lcbi_growth_rate,
-                                                rwmbi_growth_rate,
-                                                rwmbi_threshold,
-                                                lcbi_stability_diagram)
\ No newline at end of file
+from mbtrack2.instability.instabilities import (
+    cbi_threshold,
+    lcbi_growth_rate,
+    lcbi_growth_rate_mode,
+    lcbi_stability_diagram,
+    mbi_threshold,
+    rwmbi_growth_rate,
+    rwmbi_threshold,
+)
+from mbtrack2.instability.ions import (
+    fast_beam_ion,
+    ion_cross_section,
+    ion_frequency,
+    plot_critical_mass,
+)
diff --git a/mbtrack2/instability/instabilities.py b/mbtrack2/instability/instabilities.py
index b5e62461ca720fdcf8dd76322065f02bdebfd7e2..12e2237a9a4855689b3fb3b093a7adcb5333fcf3 100644
--- a/mbtrack2/instability/instabilities.py
+++ b/mbtrack2/instability/instabilities.py
@@ -3,11 +3,13 @@
 General calculations about instability thresholds.
 """
 
-import numpy as np
-import matplotlib.pyplot as plt
-from scipy.constants import c, m_e, e, pi, epsilon_0, mu_0
 import math
 
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.constants import c, e, epsilon_0, m_e, mu_0, pi
+
+
 def mbi_threshold(ring, sigma, R, b):
     """
     Compute the microbunching instability (MBI) threshold for a bunched beam
@@ -33,17 +35,18 @@ def mbi_threshold(ring, sigma, R, b):
     [2] : D. Zhou, "Coherent synchrotron radiation and microwave instability in 
     electron storage rings", PhD thesis, p112
     """
-    
+
     sigma = sigma * c
-    Ia = 4*pi*epsilon_0*m_e*c**3/e # Alfven current
-    chi = sigma*(R/b**3)**(1/2) # Shielding paramter
+    Ia = 4 * pi * epsilon_0 * m_e * c**3 / e  # Alfven current
+    chi = sigma * (R / b**3)**(1 / 2)  # Shielding paramter
     xi = 0.5 + 0.34*chi
     N = (ring.L * Ia * ring.ac * ring.gamma * ring.sigma_delta**2 * xi *
-        sigma**(1/3) / ( c * e * R**(1/3) ))
-    I = N*e/ring.T0
-    
+         sigma**(1 / 3) / (c * e * R**(1 / 3)))
+    I = N * e / ring.T0
+
     return I
 
+
 def cbi_threshold(ring, I, Vrf, f, beta, Ncav=1):
     """
     Compute the longitudinal and transverse coupled bunch instability 
@@ -81,15 +84,25 @@ def cbi_threshold(ring, I, Vrf, f, beta, Ncav=1):
     Instabilities in Electron Storage Rings Driven By Quadrupole Higher Order 
     Modes." 7th Int. Particle Accelerator Conf.(IPAC'16), Busan, Korea. 
     """
-    
-    fs = ring.synchrotron_tune(Vrf)*ring.f0
-    Zlong = fs/(f*ring.ac*ring.tau[2]) * (2*ring.E0) / (ring.f0 * I * Ncav)
-    Zxdip = 1/(ring.tau[0]*beta[0]) * (2*ring.E0) / (ring.f0 * I * Ncav)
-    Zydip = 1/(ring.tau[1]*beta[1]) * (2*ring.E0) / (ring.f0 * I * Ncav)
-    
+
+    fs = ring.synchrotron_tune(Vrf) * ring.f0
+    Zlong = fs / (f * ring.ac * ring.tau[2]) * (2 * ring.E0) / (ring.f0 * I *
+                                                                Ncav)
+    Zxdip = 1 / (ring.tau[0] * beta[0]) * (2 * ring.E0) / (ring.f0 * I * Ncav)
+    Zydip = 1 / (ring.tau[1] * beta[1]) * (2 * ring.E0) / (ring.f0 * I * Ncav)
+
     return (Zlong, Zxdip, Zydip)
 
-def lcbi_growth_rate_mode(ring, I, Vrf, M, mu, fr=None, RL=None, QL=None, Z=None):
+
+def lcbi_growth_rate_mode(ring,
+                          I,
+                          Vrf,
+                          M,
+                          mu,
+                          fr=None,
+                          RL=None,
+                          QL=None,
+                          Z=None):
     """
     Compute the longitudinal coupled bunch instability growth rate driven by
     an impedance for a given coupled bunch mode mu [1].
@@ -131,7 +144,7 @@ def lcbi_growth_rate_mode(ring, I, Vrf, M, mu, fr=None, RL=None, QL=None, Z=None
 
     nu_s = ring.synchrotron_tune(Vrf)
     factor = ring.eta() * I / (4 * np.pi * ring.E0 * nu_s)
-    
+
     if isinstance(fr, (float, int)):
         fr = np.array([fr])
     elif isinstance(fr, list):
@@ -144,30 +157,35 @@ def lcbi_growth_rate_mode(ring, I, Vrf, M, mu, fr=None, RL=None, QL=None, Z=None
         QL = np.array([QL])
     elif isinstance(QL, list):
         QL = np.array(QL)
-        
+
     if Z is None:
         omega_r = 2 * np.pi * fr
         n_max = int(10 * omega_r.max() / (ring.omega0 * M))
+
         def Zr(omega):
             Z = 0
             for i in range(len(fr)):
-                Z += np.real(RL[i] / (1 + 1j * QL[i] * (omega_r[i]/omega - omega/omega_r[i])))
+                Z += np.real(RL[i] /
+                             (1 + 1j * QL[i] *
+                              (omega_r[i] / omega - omega / omega_r[i])))
             return Z
     else:
         fmax = Z.data.index.max()
         n_max = int(2 * np.pi * fmax / (ring.omega0 * M))
+
         def Zr(omega):
-            return np.real( Z( omega / (2*np.pi) ) )
-        
+            return np.real(Z(omega / (2 * np.pi)))
+
     n0 = np.arange(n_max)
     n1 = np.arange(1, n_max)
-    omega_p = ring.omega0 * (n0 * M + mu + nu_s)
-    omega_m = ring.omega0 * (n1 * M - mu - nu_s)
-        
-    sum_val = np.sum(omega_p*Zr(omega_p)) - np.sum(omega_m*Zr(omega_m))
+    omega_p = ring.omega0 * (n0*M + mu + nu_s)
+    omega_m = ring.omega0 * (n1*M - mu - nu_s)
+
+    sum_val = np.sum(omega_p * Zr(omega_p)) - np.sum(omega_m * Zr(omega_m))
 
     return factor * sum_val
-    
+
+
 def lcbi_growth_rate(ring, I, Vrf, M, fr=None, RL=None, QL=None, Z=None):
     """
     Compute the maximum growth rate for longitudinal coupled bunch instability 
@@ -212,13 +230,22 @@ def lcbi_growth_rate(ring, I, Vrf, M, fr=None, RL=None, QL=None, Z=None):
     """
     growth_rates = np.zeros(M)
     for i in range(M):
-        growth_rates[i] = lcbi_growth_rate_mode(ring, I, Vrf, M, i, fr=fr, RL=RL, QL=QL, Z=Z)
-    
+        growth_rates[i] = lcbi_growth_rate_mode(ring,
+                                                I,
+                                                Vrf,
+                                                M,
+                                                i,
+                                                fr=fr,
+                                                RL=RL,
+                                                QL=QL,
+                                                Z=Z)
+
     growth_rate = np.max(growth_rates)
     mu = np.argmax(growth_rates)
-    
+
     return growth_rate, mu, growth_rates
 
+
 def lcbi_stability_diagram(ring, I, Vrf, M, modes, cavity_list, detune_range):
     """
     Plot longitudinal coupled bunch instability stability diagram for a 
@@ -254,29 +281,46 @@ def lcbi_stability_diagram(ring, I, Vrf, M, modes, cavity_list, detune_range):
         Show the shunt impedance threshold for different coupled bunch modes.
 
     """
-    
+
     Rth = np.zeros_like(detune_range)
     fig, ax = plt.subplots()
 
     for mu in modes:
         fixed_gr = 0
         for cav in cavity_list[:-1]:
-            fixed_gr += lcbi_growth_rate_mode(ring, I=I, Vrf=Vrf, mu=mu, fr=cav.fr, RL=cav.RL, QL=cav.QL, M=M)
-        
+            fixed_gr += lcbi_growth_rate_mode(ring,
+                                              I=I,
+                                              Vrf=Vrf,
+                                              mu=mu,
+                                              fr=cav.fr,
+                                              RL=cav.RL,
+                                              QL=cav.QL,
+                                              M=M)
+
         cav = cavity_list[-1]
         for i, det in enumerate(detune_range):
-            gr = lcbi_growth_rate_mode(ring, I=I, Vrf=Vrf, mu=mu, fr=cav.m*ring.f1 + det, RL=cav.RL, QL=cav.QL, M=M)
-            Rth[i] = (1/ring.tau[2] - fixed_gr) * cav.Rs / gr
+            gr = lcbi_growth_rate_mode(ring,
+                                       I=I,
+                                       Vrf=Vrf,
+                                       mu=mu,
+                                       fr=cav.m * ring.f1 + det,
+                                       RL=cav.RL,
+                                       QL=cav.QL,
+                                       M=M)
+            Rth[i] = (1 / ring.tau[2] - fixed_gr) * cav.Rs / gr
+
+        ax.plot(detune_range * 1e-3,
+                Rth * 1e-6,
+                label="$\mu$ = " + str(int(mu)))
 
-        ax.plot(detune_range*1e-3, Rth*1e-6, label="$\mu$ = " + str(int(mu)))
-    
     plt.xlabel(r"$\Delta f$ [kHz]")
     plt.ylabel(r"$R_{s,max}$ $[M\Omega]$")
     plt.yscale("log")
     plt.legend()
-        
+
     return fig
-    
+
+
 def rwmbi_growth_rate(ring, current, beff, rho_material, plane='x'):
     """
     Compute the growth rate of the transverse coupled-bunch instability induced
@@ -301,19 +345,21 @@ def rwmbi_growth_rate(ring, current, beff, rho_material, plane='x'):
     diffraction-limited storage ring", J. Synchrotron Rad. Vol 21, 2014. pp.937-960 
 
     """
-    plane_dict = {'x':0, 'y':1}
+    plane_dict = {'x': 0, 'y': 1}
     index = plane_dict[plane]
     beta0 = ring.optics.local_beta[index]
     omega0 = ring.omega0
     E0 = ring.E0
-    R = ring.L/(2*np.pi)
+    R = ring.L / (2 * np.pi)
     frac_tune, int_tune = math.modf(ring.tune[index])
-    Z0 = mu_0*c
-    
-    gr = (beta0*omega0*current*R) /(4*np.pi*E0*beff**3) * ((2*c*Z0*rho_material) / ((1-frac_tune)*omega0))**0.5
-    
+    Z0 = mu_0 * c
+
+    gr = (beta0*omega0*current*R) / (4 * np.pi * E0 * beff**3) * (
+        (2*c*Z0*rho_material) / ((1-frac_tune) * omega0))**0.5
+
     return gr
 
+
 def rwmbi_threshold(ring, beff, rho_material, plane='x'):
     """
     Compute the threshold current of the transverse coupled-bunch instability 
@@ -336,16 +382,16 @@ def rwmbi_threshold(ring, beff, rho_material, plane='x'):
     diffraction-limited storage ring", J. Synchrotron Rad. Vol 21, 2014. pp.937-960 
 
     """
-    plane_dict = {'x':0, 'y':1}
+    plane_dict = {'x': 0, 'y': 1}
     index = plane_dict[plane]
     beta0 = ring.optics.local_beta[index]
     omega0 = ring.omega0
     E0 = ring.E0
     tau_rad = ring.tau[index]
     frac_tune, int_tune = math.modf(ring.tune[index])
-    Z0 = mu_0*c
-    
-    Ith = (4*np.pi*E0*beff**3) / (c*beta0*tau_rad) * (((1-frac_tune)*omega0) / (2*c*Z0*rho_material))**0.5
-    
+    Z0 = mu_0 * c
+
+    Ith = (4 * np.pi * E0 * beff**3) / (c*beta0*tau_rad) * ((
+        (1-frac_tune) * omega0) / (2*c*Z0*rho_material))**0.5
+
     return Ith
-       
\ No newline at end of file
diff --git a/mbtrack2/instability/ions.py b/mbtrack2/instability/ions.py
index d6289295d4124b4674af9013c3af397e56f19af7..2ca737772d467d29b91208b443c615abe0bef25a 100644
--- a/mbtrack2/instability/ions.py
+++ b/mbtrack2/instability/ions.py
@@ -4,13 +4,24 @@ Various calculations about ion trapping and instabilities in electron storage
 rings.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
-from scipy.constants import c, m_e, e, pi, epsilon_0, hbar, Boltzmann, physical_constants, m_p
+import numpy as np
+from scipy.constants import (
+    Boltzmann,
+    c,
+    e,
+    epsilon_0,
+    hbar,
+    m_e,
+    m_p,
+    physical_constants,
+    pi,
+)
 
-rp = 1/(4*pi*epsilon_0) * e**2 / (m_p * c**2)
+rp = 1 / (4*pi*epsilon_0) * e**2 / (m_p * c**2)
 re = physical_constants["classical electron radius"][0]
 
+
 def ion_cross_section(ring, ion):
     """
     Compute the collisional ionization cross section.
@@ -49,12 +60,22 @@ def ion_cross_section(ring, ion):
         C0 = 8.1
     else:
         raise NotImplementedError
-        
-    sigma = 4*pi*(hbar/m_e/c)**2 * (M02 *(1/ring.beta**2 * np.log(ring.beta**2/(1-ring.beta**2)) - 1) + C0/ring.beta**2)
-    
+
+    sigma = 4 * pi * (hbar / m_e / c)**2 * (
+        M02 * (1 / ring.beta**2 * np.log(ring.beta**2 /
+                                         (1 - ring.beta**2)) - 1) +
+        C0 / ring.beta**2)
+
     return sigma
 
-def ion_frequency(N, Lsep, sigmax, sigmay, ion="CO", dim="y", express="coupling"):
+
+def ion_frequency(N,
+                  Lsep,
+                  sigmax,
+                  sigmay,
+                  ion="CO",
+                  dim="y",
+                  express="coupling"):
     """
     Compute the ion oscillation frequnecy.
 
@@ -92,7 +113,7 @@ def ion_frequency(N, Lsep, sigmax, sigmay, ion="CO", dim="y", express="coupling"
     instability. II. effect of ion decoherence", Physical Review E 52 (1995).
 
     """
-    
+
     if ion == "CO":
         A = 28
     elif ion == "H2":
@@ -103,25 +124,38 @@ def ion_frequency(N, Lsep, sigmax, sigmay, ion="CO", dim="y", express="coupling"
         A = 16
     elif ion == "CO2":
         A = 44
-    
+
     if dim == "y":
         pass
     elif dim == "x":
         sigmay, sigmax = sigmax, sigmay
     else:
         raise ValueError
-        
+
     if express == "coupling":
-        k = 3/2
+        k = 3 / 2
     elif express == "no_coupling":
         k = 1
-        
-    f = c * np.sqrt( 2 * rp * N / ( A * k * Lsep * sigmay * (sigmax + sigmay) ) ) / (2*pi)
-    
+
+    f = c * np.sqrt(2 * rp * N / (A * k * Lsep * sigmay *
+                                  (sigmax+sigmay))) / (2*pi)
+
     return f
-    
-def fast_beam_ion(ring, Nb, nb, Lsep, sigmax, sigmay, P, T, beta, 
-                  model="linear", delta_omega = 0, ion="CO", dim="y"):
+
+
+def fast_beam_ion(ring,
+                  Nb,
+                  nb,
+                  Lsep,
+                  sigmax,
+                  sigmay,
+                  P,
+                  T,
+                  beta,
+                  model="linear",
+                  delta_omega=0,
+                  ion="CO",
+                  dim="y"):
     """
     Compute fast beam ion instability rise time [1].
     
@@ -195,32 +229,35 @@ def fast_beam_ion(ring, Nb, nb, Lsep, sigmax, sigmay, P, T, beta,
         sigmay, sigmax = sigmax, sigmay
     else:
         raise ValueError
-        
+
     if ion == "CO":
         A = 28
     elif ion == "H2":
         A = 2
-        
+
     sigma_i = ion_cross_section(ring, ion)
-    
-    d_gas = P/(Boltzmann*T)
-    
-    num = 4 * d_gas * sigma_i * beta * Nb**(3/2) * nb**2 * re * rp**(1/2) * Lsep**(1/2) * c
-    den = 3 * np.sqrt(3) * ring.gamma * sigmay**(3/2) * (sigmay + sigmax)**(3/2) * A**(1/2)
-    
-    tau = den/num
-    
+
+    d_gas = P / (Boltzmann*T)
+
+    num = 4 * d_gas * sigma_i * beta * Nb**(3 / 2) * nb**2 * re * rp**(
+        1 / 2) * Lsep**(1 / 2) * c
+    den = 3 * np.sqrt(3) * ring.gamma * sigmay**(3 / 2) * (sigmay + sigmax)**(
+        3 / 2) * A**(1 / 2)
+
+    tau = den / num
+
     if model == "decoherence":
         tau = tau * 2 * np.sqrt(2) * nb * Lsep * delta_omega / c
     elif model == "non-linear":
         fi = ion_frequency(Nb, Lsep, sigmax, sigmay, ion, dim)
-        tau = tau * 2 * pi * fi * ring.T1 * nb**(3/2)
+        tau = tau * 2 * pi * fi * ring.T1 * nb**(3 / 2)
     elif model == "linear":
         pass
     else:
         raise ValueError("model unknown")
-    
-    return tau 
+
+    return tau
+
 
 def plot_critical_mass(ring, bunch_charge, bunch_spacing, n_points=1e4):
     """
@@ -247,28 +284,29 @@ def plot_critical_mass(ring, bunch_charge, bunch_spacing, n_points=1e4):
     Université Paris-Saclay).
 
     """
-    
+
     n_points = int(n_points)
     s = np.linspace(0, ring.L, n_points)
     sigma = ring.sigma(s)
-    N = np.abs(bunch_charge/e)
-    
-    Ay = N*rp*bunch_spacing*c/(2*sigma[2,:]*(sigma[2,:] + sigma[0,:]))
-    Ax = N*rp*bunch_spacing*c/(2*sigma[0,:]*(sigma[2,:] + sigma[0,:]))
-    
+    N = np.abs(bunch_charge / e)
+
+    Ay = N * rp * bunch_spacing * c / (2 * sigma[2, :] *
+                                       (sigma[2, :] + sigma[0, :]))
+    Ax = N * rp * bunch_spacing * c / (2 * sigma[0, :] *
+                                       (sigma[2, :] + sigma[0, :]))
+
     fig = plt.figure()
     ax = plt.gca()
     ax.plot(s, Ax, label=r"$A_x^c$")
     ax.plot(s, Ay, label=r"$A_y^c$")
     ax.set_yscale("log")
-    ax.plot(s, np.ones_like(s)*2, label=r"$H_2^+$")
-    ax.plot(s, np.ones_like(s)*16, label=r"$H_2O^+$")
-    ax.plot(s, np.ones_like(s)*18, label=r"$CH_4^+$")
-    ax.plot(s, np.ones_like(s)*28, label=r"$CO^+$")
-    ax.plot(s, np.ones_like(s)*44, label=r"$CO_2^+$")
+    ax.plot(s, np.ones_like(s) * 2, label=r"$H_2^+$")
+    ax.plot(s, np.ones_like(s) * 16, label=r"$H_2O^+$")
+    ax.plot(s, np.ones_like(s) * 18, label=r"$CH_4^+$")
+    ax.plot(s, np.ones_like(s) * 28, label=r"$CO^+$")
+    ax.plot(s, np.ones_like(s) * 44, label=r"$CO_2^+$")
     ax.legend()
     ax.set_ylabel("Critical mass")
     ax.set_xlabel("Longitudinal position [m]")
-    
+
     return fig
-    
\ No newline at end of file
diff --git a/mbtrack2/tracking/__init__.py b/mbtrack2/tracking/__init__.py
index 6f72151ebe7cf8ce7a8e72a12395385fc6d08fc1..02843e280f5cd945178b040b903c468b095765cc 100644
--- a/mbtrack2/tracking/__init__.py
+++ b/mbtrack2/tracking/__init__.py
@@ -1,31 +1,33 @@
 # -*- coding: utf-8 -*-
-from mbtrack2.tracking.particles import (Electron, 
-                                         Proton, 
-                                         Bunch, 
-                                         Beam, 
-                                         Particle)
-from mbtrack2.tracking.synchrotron import Synchrotron
-from mbtrack2.tracking.rf import (RFCavity, 
-                                  CavityResonator,
-                                  ProportionalLoop,
-                                  TunerLoop,
-                                  ProportionalIntegralLoop,
-                                  DirectFeedback)
-from mbtrack2.tracking.parallel import Mpi
-from mbtrack2.tracking.element import (Element, 
-                                       LongitudinalMap, 
-                                       TransverseMap, 
-                                       SynchrotronRadiation,
-                                       SkewQuadrupole,
-                                       TransverseMapSector,
-                                       transverse_map_sector_generator)
-from mbtrack2.tracking.aperture import (CircularAperture, 
-                                        ElipticalAperture,
-                                        RectangularAperture, 
-                                        LongitudinalAperture)
-from mbtrack2.tracking.wakepotential import (WakePotential, 
-                                             LongRangeResistiveWall)
-                                       
-from mbtrack2.tracking.feedback import (ExponentialDamper,
-                                        FIRDamper)
+from mbtrack2.tracking.aperture import (
+    CircularAperture,
+    ElipticalAperture,
+    LongitudinalAperture,
+    RectangularAperture,
+)
+from mbtrack2.tracking.element import (
+    Element,
+    LongitudinalMap,
+    SkewQuadrupole,
+    SynchrotronRadiation,
+    TransverseMap,
+    TransverseMapSector,
+    transverse_map_sector_generator,
+)
+from mbtrack2.tracking.feedback import ExponentialDamper, FIRDamper
 from mbtrack2.tracking.monitors import *
+from mbtrack2.tracking.parallel import Mpi
+from mbtrack2.tracking.particles import Beam, Bunch, Electron, Particle, Proton
+from mbtrack2.tracking.rf import (
+    CavityResonator,
+    DirectFeedback,
+    ProportionalIntegralLoop,
+    ProportionalLoop,
+    RFCavity,
+    TunerLoop,
+)
+from mbtrack2.tracking.synchrotron import Synchrotron
+from mbtrack2.tracking.wakepotential import (
+    LongRangeResistiveWall,
+    WakePotential,
+)
diff --git a/mbtrack2/tracking/aperture.py b/mbtrack2/tracking/aperture.py
index 67b51bf5d8a7d72cc855c7d6b028f423725b93c7..7f3645c40789737ac82a119c56f50b659ad59650 100644
--- a/mbtrack2/tracking/aperture.py
+++ b/mbtrack2/tracking/aperture.py
@@ -4,8 +4,10 @@ This module defines aperture elements for tracking.
 """
 
 import numpy as np
+
 from mbtrack2.tracking.element import Element
 
+
 class CircularAperture(Element):
     """
     Circular aperture element. The particles which are outside of the circle 
@@ -16,12 +18,11 @@ class CircularAperture(Element):
     radius : float
         radius of the circle in [m]
     """
-    
     def __init__(self, radius):
         self.radius = radius
         self.radius_squared = radius**2
-    
-    @Element.parallel    
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -32,9 +33,11 @@ class CircularAperture(Element):
         ----------
         bunch : Bunch or Beam object
         """
-        alive = bunch.particles["x"]**2 + bunch.particles["y"]**2 < self.radius_squared
+        alive = bunch.particles["x"]**2 + bunch.particles[
+            "y"]**2 < self.radius_squared
         bunch.alive[~alive] = False
-        
+
+
 class ElipticalAperture(Element):
     """
     Eliptical aperture element. The particles which are outside of the elipse 
@@ -47,14 +50,13 @@ class ElipticalAperture(Element):
     Y_radius : float
         vertical radius of the elipse in [m]
     """
-    
     def __init__(self, X_radius, Y_radius):
         self.X_radius = X_radius
         self.X_radius_squared = X_radius**2
         self.Y_radius = Y_radius
-        self.Y_radius_squared =Y_radius**2
-    
-    @Element.parallel    
+        self.Y_radius_squared = Y_radius**2
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -65,10 +67,11 @@ class ElipticalAperture(Element):
         ----------
         bunch : Bunch or Beam object
         """
-        alive = (bunch.particles["x"]**2/self.X_radius_squared + 
-                 bunch.particles["y"]**2/self.Y_radius_squared < 1)
+        alive = (bunch.particles["x"]**2 / self.X_radius_squared +
+                 bunch.particles["y"]**2 / self.Y_radius_squared < 1)
         bunch.alive[~alive] = False
 
+
 class RectangularAperture(Element):
     """
     Rectangular aperture element. The particles which are outside of the 
@@ -85,15 +88,13 @@ class RectangularAperture(Element):
     Y_bottom : float, optional
         bottom vertical aperture of the rectangle in [m]
     """
-    
     def __init__(self, X_right, Y_top, X_left=None, Y_bottom=None):
         self.X_right = X_right
         self.X_left = X_left
         self.Y_top = Y_top
         self.Y_bottom = Y_bottom
- 
-    
-    @Element.parallel    
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -104,22 +105,23 @@ class RectangularAperture(Element):
         ----------
         bunch : Bunch or Beam object
         """
-        
+
         if (self.X_left is None):
             alive_X = np.abs(bunch.particles["x"]) < self.X_right
         else:
             alive_X = ((bunch.particles["x"] < self.X_right) &
-                     (bunch.particles["x"] > self.X_left))
-            
+                       (bunch.particles["x"] > self.X_left))
+
         if (self.Y_bottom is None):
             alive_Y = np.abs(bunch.particles["y"]) < self.Y_top
         else:
             alive_Y = ((bunch.particles["y"] < self.Y_top) &
-                     (bunch.particles["y"] > self.Y_bottom))
+                       (bunch.particles["y"] > self.Y_bottom))
 
         alive = alive_X & alive_Y
         bunch.alive[~alive] = False
-        
+
+
 class LongitudinalAperture(Element):
     """
     Longitudinal aperture element. The particles which are outside of the 
@@ -133,15 +135,14 @@ class LongitudinalAperture(Element):
     tau_low : float, optional
         Lower longitudinal bound in [s].
     """
-    
     def __init__(self, tau_up, tau_low=None):
         self.tau_up = tau_up
         if tau_low is None:
-            self.tau_low = tau_up*-1
+            self.tau_low = tau_up * -1
         else:
             self.tau_low = tau_low
-    
-    @Element.parallel    
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -152,8 +153,8 @@ class LongitudinalAperture(Element):
         ----------
         bunch : Bunch or Beam object
         """
-        
+
         alive = ((bunch.particles["tau"] < self.tau_up) &
                  (bunch.particles["tau"] > self.tau_low))
 
-        bunch.alive[~alive] = False
\ No newline at end of file
+        bunch.alive[~alive] = False
diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index 231ba253212e31b293df2a27fdee80d3431a6bc7..7ccfa4d8caae64fab0190b9414693df4887b52cd 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -5,18 +5,20 @@ an abstract base class which is to be used as mother class to every elements
 included in the tracking.
 """
 
-import numpy as np
 from abc import ABCMeta, abstractmethod
-from functools import wraps
 from copy import deepcopy
+from functools import wraps
+
+import numpy as np
+
 from mbtrack2.tracking.particles import Beam
 
+
 class Element(metaclass=ABCMeta):
     """
     Abstract Element class used for subclass inheritance to define all kinds 
     of objects which intervene in the tracking.
     """
-
     @abstractmethod
     def track(self, beam):
         """
@@ -28,7 +30,7 @@ class Element(metaclass=ABCMeta):
         beam : Beam object
         """
         raise NotImplementedError
-        
+
     @staticmethod
     def parallel(track):
         """
@@ -65,8 +67,10 @@ class Element(metaclass=ABCMeta):
                 self = args[0]
                 bunch = args[1]
                 track(self, bunch, *args[2:], **kwargs)
+
         return track_wrapper
-        
+
+
 class LongitudinalMap(Element):
     """
     Longitudinal map for a single turn in the synchrotron.
@@ -75,10 +79,9 @@ class LongitudinalMap(Element):
     ----------
     ring : Synchrotron object
     """
-    
     def __init__(self, ring):
         self.ring = ring
-        
+
     @Element.parallel
     def track(self, bunch):
         """
@@ -91,7 +94,9 @@ class LongitudinalMap(Element):
         bunch : Bunch or Beam object
         """
         bunch["delta"] -= self.ring.U0 / self.ring.E0
-        bunch["tau"] += self.ring.eta(bunch["delta"]) * self.ring.T0 * bunch["delta"]
+        bunch["tau"] += self.ring.eta(
+            bunch["delta"]) * self.ring.T0 * bunch["delta"]
+
 
 class SynchrotronRadiation(Element):
     """
@@ -104,12 +109,11 @@ class SynchrotronRadiation(Element):
     switch : bool array of shape (3,), optional
         allow to choose on which plane the synchrotron radiation is active
     """
-    
-    def __init__(self, ring, switch = np.ones((3,), dtype=bool)):
+    def __init__(self, ring, switch=np.ones((3, ), dtype=bool)):
         self.ring = ring
         self.switch = switch
-        
-    @Element.parallel        
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -122,19 +126,26 @@ class SynchrotronRadiation(Element):
         """
         if (self.switch[0] == True):
             rand = np.random.normal(size=len(bunch))
-            bunch["delta"] = ((1 - 2*self.ring.T0/self.ring.tau[2])*bunch["delta"] +
-                 2*self.ring.sigma_delta*(self.ring.T0/self.ring.tau[2])**0.5*rand)
-            
+            bunch["delta"] = (
+                (1 - 2 * self.ring.T0 / self.ring.tau[2]) * bunch["delta"] +
+                2 * self.ring.sigma_delta *
+                (self.ring.T0 / self.ring.tau[2])**0.5 * rand)
+
         if (self.switch[1] == True):
             rand = np.random.normal(size=len(bunch))
-            bunch["xp"] = ((1 - 2*self.ring.T0/self.ring.tau[0])*bunch["xp"] +
-                 2*self.ring.sigma()[1]*(self.ring.T0/self.ring.tau[0])**0.5*rand)
-       
+            bunch["xp"] = (
+                (1 - 2 * self.ring.T0 / self.ring.tau[0]) * bunch["xp"] +
+                2 * self.ring.sigma()[1] *
+                (self.ring.T0 / self.ring.tau[0])**0.5 * rand)
+
         if (self.switch[2] == True):
             rand = np.random.normal(size=len(bunch))
-            bunch["yp"] = ((1 - 2*self.ring.T0/self.ring.tau[1])*bunch["yp"] +
-                 2*self.ring.sigma()[3]*(self.ring.T0/self.ring.tau[1])**0.5*rand)
-        
+            bunch["yp"] = (
+                (1 - 2 * self.ring.T0 / self.ring.tau[1]) * bunch["yp"] +
+                2 * self.ring.sigma()[3] *
+                (self.ring.T0 / self.ring.tau[1])**0.5 * rand)
+
+
 class TransverseMap(Element):
     """
     Transverse map for a single turn in the synchrotron.
@@ -143,20 +154,21 @@ class TransverseMap(Element):
     ----------
     ring : Synchrotron object
     """
-    
     def __init__(self, ring):
         self.ring = ring
         self.alpha = self.ring.optics.local_alpha
         self.beta = self.ring.optics.local_beta
         self.gamma = self.ring.optics.local_gamma
-        self.dispersion = self.ring.optics.local_dispersion        
+        self.dispersion = self.ring.optics.local_dispersion
         if self.ring.adts is not None:
-            self.adts_poly = [np.poly1d(self.ring.adts[0]),
-                              np.poly1d(self.ring.adts[1]),
-                              np.poly1d(self.ring.adts[2]), 
-                              np.poly1d(self.ring.adts[3])]
-    
-    @Element.parallel    
+            self.adts_poly = [
+                np.poly1d(self.ring.adts[0]),
+                np.poly1d(self.ring.adts[1]),
+                np.poly1d(self.ring.adts[2]),
+                np.poly1d(self.ring.adts[3])
+            ]
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -170,10 +182,10 @@ class TransverseMap(Element):
 
         # Compute phase advance which depends on energy via chromaticity and ADTS
         if self.ring.adts is None:
-            phase_advance_x = 2*np.pi * (self.ring.tune[0] + 
-                                         self.ring.chro[0]*bunch["delta"])
-            phase_advance_y = 2*np.pi * (self.ring.tune[1] + 
-                                         self.ring.chro[1]*bunch["delta"])
+            phase_advance_x = 2 * np.pi * (self.ring.tune[0] +
+                                           self.ring.chro[0] * bunch["delta"])
+            phase_advance_y = 2 * np.pi * (self.ring.tune[1] +
+                                           self.ring.chro[1] * bunch["delta"])
         else:
             Jx = (self.ring.optics.local_gamma[0] * bunch['x']**2) + \
                   (2*self.ring.optics.local_alpha[0] * bunch['x']*bunch['xp']) + \
@@ -181,46 +193,53 @@ class TransverseMap(Element):
             Jy = (self.ring.optics.local_gamma[1] * bunch['y']**2) + \
                   (2*self.ring.optics.local_alpha[1] * bunch['y']*bunch['yp']) + \
                   (self.ring.optics.local_beta[1] * bunch['yp']**2)
-            phase_advance_x = 2*np.pi * (self.ring.tune[0] + 
-                                         self.ring.chro[0]*bunch["delta"] + 
-                                         self.adts_poly[0](Jx) + 
-                                         self.adts_poly[2](Jy))
-            phase_advance_y = 2*np.pi * (self.ring.tune[1] + 
-                                         self.ring.chro[1]*bunch["delta"] +
-                                         self.adts_poly[1](Jx) + 
-                                         self.adts_poly[3](Jy))
-        
+            phase_advance_x = 2 * np.pi * (
+                self.ring.tune[0] + self.ring.chro[0] * bunch["delta"] +
+                self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
+            phase_advance_y = 2 * np.pi * (
+                self.ring.tune[1] + self.ring.chro[1] * bunch["delta"] +
+                self.adts_poly[1](Jx) + self.adts_poly[3](Jy))
+
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
-        matrix = np.zeros((6,6,len(bunch)))
-        
+        matrix = np.zeros((6, 6, len(bunch)))
+
         # Horizontal
-        matrix[0,0,:] = np.cos(phase_advance_x) + self.alpha[0]*np.sin(phase_advance_x)
-        matrix[0,1,:] = self.beta[0]*np.sin(phase_advance_x)
-        matrix[0,2,:] = self.dispersion[0]
-        matrix[1,0,:] = -1*self.gamma[0]*np.sin(phase_advance_x)
-        matrix[1,1,:] = np.cos(phase_advance_x) - self.alpha[0]*np.sin(phase_advance_x)
-        matrix[1,2,:] = self.dispersion[1]
-        matrix[2,2,:] = 1
-        
+        matrix[0, 0, :] = np.cos(
+            phase_advance_x) + self.alpha[0] * np.sin(phase_advance_x)
+        matrix[0, 1, :] = self.beta[0] * np.sin(phase_advance_x)
+        matrix[0, 2, :] = self.dispersion[0]
+        matrix[1, 0, :] = -1 * self.gamma[0] * np.sin(phase_advance_x)
+        matrix[1, 1, :] = np.cos(
+            phase_advance_x) - self.alpha[0] * np.sin(phase_advance_x)
+        matrix[1, 2, :] = self.dispersion[1]
+        matrix[2, 2, :] = 1
+
         # Vertical
-        matrix[3,3,:] = np.cos(phase_advance_y) + self.alpha[1]*np.sin(phase_advance_y)
-        matrix[3,4,:] = self.beta[1]*np.sin(phase_advance_y)
-        matrix[3,5,:] = self.dispersion[2]
-        matrix[4,3,:] = -1*self.gamma[1]*np.sin(phase_advance_y)
-        matrix[4,4,:] = np.cos(phase_advance_y) - self.alpha[1]*np.sin(phase_advance_y)
-        matrix[4,5,:] = self.dispersion[3]
-        matrix[5,5,:] = 1
-        
-        x = matrix[0,0,:]*bunch["x"] + matrix[0,1,:]*bunch["xp"] + matrix[0,2,:]*bunch["delta"]
-        xp = matrix[1,0,:]*bunch["x"] + matrix[1,1,:]*bunch["xp"] + matrix[1,2,:]*bunch["delta"]
-        y =  matrix[3,3,:]*bunch["y"] + matrix[3,4,:]*bunch["yp"] + matrix[3,5,:]*bunch["delta"]
-        yp = matrix[4,3,:]*bunch["y"] + matrix[4,4,:]*bunch["yp"] + matrix[4,5,:]*bunch["delta"]
-        
+        matrix[3, 3, :] = np.cos(
+            phase_advance_y) + self.alpha[1] * np.sin(phase_advance_y)
+        matrix[3, 4, :] = self.beta[1] * np.sin(phase_advance_y)
+        matrix[3, 5, :] = self.dispersion[2]
+        matrix[4, 3, :] = -1 * self.gamma[1] * np.sin(phase_advance_y)
+        matrix[4, 4, :] = np.cos(
+            phase_advance_y) - self.alpha[1] * np.sin(phase_advance_y)
+        matrix[4, 5, :] = self.dispersion[3]
+        matrix[5, 5, :] = 1
+
+        x = matrix[0, 0, :] * bunch["x"] + matrix[
+            0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"]
+        xp = matrix[1, 0, :] * bunch["x"] + matrix[
+            1, 1, :] * bunch["xp"] + matrix[1, 2, :] * bunch["delta"]
+        y = matrix[3, 3, :] * bunch["y"] + matrix[
+            3, 4, :] * bunch["yp"] + matrix[3, 5, :] * bunch["delta"]
+        yp = matrix[4, 3, :] * bunch["y"] + matrix[
+            4, 4, :] * bunch["yp"] + matrix[4, 5, :] * bunch["delta"]
+
         bunch["x"] = x
         bunch["xp"] = xp
         bunch["y"] = y
         bunch["yp"] = yp
-        
+
+
 class SkewQuadrupole:
     """
     Thin skew quadrupole element used to introduce betatron coupling (the 
@@ -234,8 +253,8 @@ class SkewQuadrupole:
     """
     def __init__(self, strength):
         self.strength = strength
-        
-    @Element.parallel        
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -246,10 +265,11 @@ class SkewQuadrupole:
         ----------
         bunch : Bunch or Beam object
         """
-        
+
         bunch['xp'] = bunch['xp'] - self.strength * bunch['y']
         bunch['yp'] = bunch['yp'] - self.strength * bunch['x']
 
+
 class TransverseMapSector(Element):
     """
     Transverse map for a sector of the synchrotron, from an initial 
@@ -282,28 +302,39 @@ class TransverseMapSector(Element):
         for details. The default is None.
 
     """
-    def __init__(self, ring, alpha0, beta0, dispersion0, alpha1, beta1, 
-                 dispersion1, phase_diff, chro_diff, adts=None):
+    def __init__(self,
+                 ring,
+                 alpha0,
+                 beta0,
+                 dispersion0,
+                 alpha1,
+                 beta1,
+                 dispersion1,
+                 phase_diff,
+                 chro_diff,
+                 adts=None):
         self.ring = ring
         self.alpha0 = alpha0
         self.beta0 = beta0
-        self.gamma0 = (1 + self.alpha0**2)/self.beta0
+        self.gamma0 = (1 + self.alpha0**2) / self.beta0
         self.dispersion0 = dispersion0
         self.alpha1 = alpha1
         self.beta1 = beta1
-        self.gamma1 = (1 + self.alpha1**2)/self.beta1
-        self.dispersion1 = dispersion1  
-        self.tune_diff = phase_diff / (2*np.pi)
+        self.gamma1 = (1 + self.alpha1**2) / self.beta1
+        self.dispersion1 = dispersion1
+        self.tune_diff = phase_diff / (2 * np.pi)
         self.chro_diff = chro_diff
         if adts is not None:
-            self.adts_poly = [np.poly1d(adts[0]),
-                              np.poly1d(adts[1]),
-                              np.poly1d(adts[2]), 
-                              np.poly1d(adts[3])]
+            self.adts_poly = [
+                np.poly1d(adts[0]),
+                np.poly1d(adts[1]),
+                np.poly1d(adts[2]),
+                np.poly1d(adts[3])
+            ]
         else:
             self.adts_poly = None
-    
-    @Element.parallel    
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the element.
@@ -317,10 +348,10 @@ class TransverseMapSector(Element):
 
         # Compute phase advance which depends on energy via chromaticity and ADTS
         if self.adts_poly is None:
-            phase_advance_x = 2*np.pi * (self.tune_diff[0] + 
-                                         self.chro_diff[0]*bunch["delta"])
-            phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
-                                         self.chro_diff[1]*bunch["delta"])
+            phase_advance_x = 2 * np.pi * (self.tune_diff[0] +
+                                           self.chro_diff[0] * bunch["delta"])
+            phase_advance_y = 2 * np.pi * (self.tune_diff[1] +
+                                           self.chro_diff[1] * bunch["delta"])
         else:
             Jx = (self.gamma0[0] * bunch['x']**2) + \
                   (2*self.alpha0[0] * bunch['x']*self['xp']) + \
@@ -328,46 +359,69 @@ class TransverseMapSector(Element):
             Jy = (self.gamma0[1] * bunch['y']**2) + \
                   (2*self.alpha0[1] * bunch['y']*bunch['yp']) + \
                   (self.beta0[1] * bunch['yp']**2)
-            phase_advance_x = 2*np.pi * (self.tune_diff[0] + 
-                                         self.chro_diff[0]*bunch["delta"] + 
-                                         self.adts_poly[0](Jx) + 
-                                         self.adts_poly[2](Jy))
-            phase_advance_y = 2*np.pi * (self.tune_diff[1] + 
-                                         self.chro_diff[1]*bunch["delta"] +
-                                         self.adts_poly[1](Jx) + 
-                                         self.adts_poly[3](Jy))
-        
+            phase_advance_x = 2 * np.pi * (
+                self.tune_diff[0] + self.chro_diff[0] * bunch["delta"] +
+                self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
+            phase_advance_y = 2 * np.pi * (
+                self.tune_diff[1] + self.chro_diff[1] * bunch["delta"] +
+                self.adts_poly[1](Jx) + self.adts_poly[3](Jy))
+
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
-        matrix = np.zeros((6,6,len(bunch)))
-        
+        matrix = np.zeros((6, 6, len(bunch)))
+
         # Horizontal
-        matrix[0,0,:] = np.sqrt(self.beta1[0]/self.beta0[0])*(np.cos(phase_advance_x) + self.alpha0[0]*np.sin(phase_advance_x))
-        matrix[0,1,:] = np.sqrt(self.beta0[0]*self.beta1[0])*np.sin(phase_advance_x)
-        matrix[0,2,:] = self.dispersion1[0] - matrix[0,0,:]*self.dispersion0[0] - matrix[0,1,:]*self.dispersion0[1]
-        matrix[1,0,:] = ((self.alpha0[0] - self.alpha1[0])*np.cos(phase_advance_x) - (1 + self.alpha0[0]*self.alpha1[0])*np.sin(phase_advance_x))/np.sqrt(self.beta0[0]*self.beta1[0])
-        matrix[1,1,:] = np.sqrt(self.beta0[0]/self.beta1[0])*(np.cos(phase_advance_x) - self.alpha1[0]*np.sin(phase_advance_x))
-        matrix[1,2,:] = self.dispersion1[1] - matrix[1,0,:]*self.dispersion0[0] - matrix[1,1,:]*self.dispersion0[1]
-        matrix[2,2,:] = 1
-        
+        matrix[0, 0, :] = np.sqrt(self.beta1[0] / self.beta0[0]) * (
+            np.cos(phase_advance_x) + self.alpha0[0] * np.sin(phase_advance_x))
+        matrix[0, 1, :] = np.sqrt(
+            self.beta0[0] * self.beta1[0]) * np.sin(phase_advance_x)
+        matrix[0, 2, :] = self.dispersion1[0] - matrix[
+            0, 0, :] * self.dispersion0[0] - matrix[0,
+                                                    1, :] * self.dispersion0[1]
+        matrix[1, 0, :] = (
+            (self.alpha0[0] - self.alpha1[0]) * np.cos(phase_advance_x) -
+            (1 + self.alpha0[0] * self.alpha1[0]) *
+            np.sin(phase_advance_x)) / np.sqrt(self.beta0[0] * self.beta1[0])
+        matrix[1, 1, :] = np.sqrt(self.beta0[0] / self.beta1[0]) * (
+            np.cos(phase_advance_x) - self.alpha1[0] * np.sin(phase_advance_x))
+        matrix[1, 2, :] = self.dispersion1[1] - matrix[
+            1, 0, :] * self.dispersion0[0] - matrix[1,
+                                                    1, :] * self.dispersion0[1]
+        matrix[2, 2, :] = 1
+
         # Vertical
-        matrix[3,3,:] = np.sqrt(self.beta1[1]/self.beta0[1])*(np.cos(phase_advance_y) + self.alpha0[1]*np.sin(phase_advance_y))
-        matrix[3,4,:] = np.sqrt(self.beta0[1]*self.beta1[1])*np.sin(phase_advance_y)
-        matrix[3,5,:] = self.dispersion1[2] - matrix[3,3,:]*self.dispersion0[2] - matrix[3,4,:]*self.dispersion0[3]
-        matrix[4,3,:] = ((self.alpha0[1] - self.alpha1[1])*np.cos(phase_advance_y) - (1 + self.alpha0[1]*self.alpha1[1])*np.sin(phase_advance_y))/np.sqrt(self.beta0[1]*self.beta1[1])
-        matrix[4,4,:] = np.sqrt(self.beta0[1]/self.beta1[1])*(np.cos(phase_advance_y) - self.alpha1[1]*np.sin(phase_advance_y))
-        matrix[4,5,:] = self.dispersion1[3] - matrix[4,3,:]*self.dispersion0[2] - matrix[4,4,:]*self.dispersion0[3]
-        matrix[5,5,:] = 1
-        
-        x = matrix[0,0,:]*bunch["x"] + matrix[0,1,:]*bunch["xp"] + matrix[0,2,:]*bunch["delta"]
-        xp = matrix[1,0,:]*bunch["x"] + matrix[1,1,:]*bunch["xp"] + matrix[1,2,:]*bunch["delta"]
-        y =  matrix[3,3,:]*bunch["y"] + matrix[3,4,:]*bunch["yp"] + matrix[3,5,:]*bunch["delta"]
-        yp = matrix[4,3,:]*bunch["y"] + matrix[4,4,:]*bunch["yp"] + matrix[4,5,:]*bunch["delta"]
-        
+        matrix[3, 3, :] = np.sqrt(self.beta1[1] / self.beta0[1]) * (
+            np.cos(phase_advance_y) + self.alpha0[1] * np.sin(phase_advance_y))
+        matrix[3, 4, :] = np.sqrt(
+            self.beta0[1] * self.beta1[1]) * np.sin(phase_advance_y)
+        matrix[3, 5, :] = self.dispersion1[2] - matrix[
+            3, 3, :] * self.dispersion0[2] - matrix[3,
+                                                    4, :] * self.dispersion0[3]
+        matrix[4, 3, :] = (
+            (self.alpha0[1] - self.alpha1[1]) * np.cos(phase_advance_y) -
+            (1 + self.alpha0[1] * self.alpha1[1]) *
+            np.sin(phase_advance_y)) / np.sqrt(self.beta0[1] * self.beta1[1])
+        matrix[4, 4, :] = np.sqrt(self.beta0[1] / self.beta1[1]) * (
+            np.cos(phase_advance_y) - self.alpha1[1] * np.sin(phase_advance_y))
+        matrix[4, 5, :] = self.dispersion1[3] - matrix[
+            4, 3, :] * self.dispersion0[2] - matrix[4,
+                                                    4, :] * self.dispersion0[3]
+        matrix[5, 5, :] = 1
+
+        x = matrix[0, 0, :] * bunch["x"] + matrix[
+            0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"]
+        xp = matrix[1, 0, :] * bunch["x"] + matrix[
+            1, 1, :] * bunch["xp"] + matrix[1, 2, :] * bunch["delta"]
+        y = matrix[3, 3, :] * bunch["y"] + matrix[
+            3, 4, :] * bunch["yp"] + matrix[3, 5, :] * bunch["delta"]
+        yp = matrix[4, 3, :] * bunch["y"] + matrix[
+            4, 4, :] * bunch["yp"] + matrix[4, 5, :] * bunch["delta"]
+
         bunch["x"] = x
         bunch["xp"] = xp
         bunch["y"] = y
         bunch["yp"] = yp
-        
+
+
 def transverse_map_sector_generator(ring, positions):
     """
     Convenience function which generate a list of TransverseMapSector elements
@@ -393,30 +447,32 @@ def transverse_map_sector_generator(ring, positions):
 
     """
     import at
+
     def _compute_chro(ring, pos, dp=1e-4):
         lat = deepcopy(ring.optics.lattice)
         lat.append(at.Marker("END"))
-        N=len(lat)
-        refpts=np.arange(N)
+        N = len(lat)
+        refpts = np.arange(N)
         *elem_neg_dp, = at.linopt2(lat, refpts=refpts, dp=-dp)
         *elem_pos_dp, = at.linopt2(lat, refpts=refpts, dp=dp)
 
         s = elem_neg_dp[2]["s_pos"]
-        mux0 = elem_neg_dp[2]['mu'][:,0]
-        mux1 = elem_pos_dp[2]['mu'][:,0]
-        muy0 = elem_neg_dp[2]['mu'][:,1]
-        muy1 = elem_pos_dp[2]['mu'][:,1]
+        mux0 = elem_neg_dp[2]['mu'][:, 0]
+        mux1 = elem_pos_dp[2]['mu'][:, 0]
+        muy0 = elem_neg_dp[2]['mu'][:, 1]
+        muy1 = elem_pos_dp[2]['mu'][:, 1]
 
-        Chrox=(mux1-mux0)/(2*dp)/2/np.pi
-        Chroy=(muy1-muy0)/(2*dp)/2/np.pi
+        Chrox = (mux1-mux0) / (2*dp) / 2 / np.pi
+        Chroy = (muy1-muy0) / (2*dp) / 2 / np.pi
         chrox = np.interp(pos, s, Chrox)
         chroy = np.interp(pos, s, Chroy)
-        
+
         return np.array([chrox, chroy])
-    
+
     if ring.optics.use_local_values:
-        raise ValueError("The Synchrotron object must be loaded from an AT lattice")
-    
+        raise ValueError(
+            "The Synchrotron object must be loaded from an AT lattice")
+
     N_sec = len(positions)
     sectors = []
     for i in range(N_sec):
@@ -426,11 +482,11 @@ def transverse_map_sector_generator(ring, positions):
         mu0 = ring.optics.mu(positions[i])
         chro0 = _compute_chro(ring, positions[i])
         if i != (N_sec - 1):
-            alpha1 = ring.optics.alpha(positions[i+1])
-            beta1 = ring.optics.beta(positions[i+1])
-            dispersion1 = ring.optics.dispersion(positions[i+1])
-            mu1 = ring.optics.mu(positions[i+1])
-            chro1 = _compute_chro(ring, positions[i+1])
+            alpha1 = ring.optics.alpha(positions[i + 1])
+            beta1 = ring.optics.beta(positions[i + 1])
+            dispersion1 = ring.optics.dispersion(positions[i + 1])
+            mu1 = ring.optics.mu(positions[i + 1])
+            chro1 = _compute_chro(ring, positions[i + 1])
         else:
             alpha1 = ring.optics.alpha(positions[0])
             beta1 = ring.optics.beta(positions[0])
@@ -439,6 +495,7 @@ def transverse_map_sector_generator(ring, positions):
             chro1 = _compute_chro(ring, ring.L)
         phase_diff = mu1 - mu0
         chro_diff = chro1 - chro0
-        sectors.append(TransverseMapSector(ring, alpha0, beta0, dispersion0, 
-                     alpha1, beta1, dispersion1, phase_diff, chro_diff))
+        sectors.append(
+            TransverseMapSector(ring, alpha0, beta0, dispersion0, alpha1,
+                                beta1, dispersion1, phase_diff, chro_diff))
     return sectors
diff --git a/mbtrack2/tracking/feedback.py b/mbtrack2/tracking/feedback.py
index b3633c5aec7e710315bdef892a222adfb79e936c..72c4c9c2924e61a5c8aaf12de644d551e2038c18 100644
--- a/mbtrack2/tracking/feedback.py
+++ b/mbtrack2/tracking/feedback.py
@@ -3,11 +3,13 @@
 This module defines both exponential and FIR based bunch by bunch damper 
 feedback for tracking.
 """
-import numpy as np 
 import matplotlib.pyplot as plt
-from mbtrack2.tracking import Element, Beam, Bunch
+import numpy as np
 
-class ExponentialDamper(Element): 
+from mbtrack2.tracking import Beam, Bunch, Element
+
+
+class ExponentialDamper(Element):
     """ 
     Simple bunch by bunch damper feedback system based on exponential damping.
     
@@ -39,8 +41,8 @@ class ExponentialDamper(Element):
             self.mean_idx = 5
         else:
             raise ValueError(f"plane should be x, y or s, not {self.plane}")
-            
-    @Element.parallel 
+
+    @Element.parallel
     def track(self, bunch):
         """
         Tracking method for the feedback system
@@ -51,12 +53,12 @@ class ExponentialDamper(Element):
         ----------
         bunch : Bunch or Beam object
         """
-        bunch[self.action] -= (2*self.ring.T0/
-                               self.damping_time*
-                               np.sin(self.phase_diff)*
+        bunch[self.action] -= (2 * self.ring.T0 / self.damping_time *
+                               np.sin(self.phase_diff) *
                                bunch.mean[self.mean_idx])
 
-class FIRDamper(Element): 
+
+class FIRDamper(Element):
     """ 
     Bunch by bunch damper feedback system based on FIR filters.
     
@@ -114,10 +116,17 @@ class FIRDamper(Element):
     2004. Transverse bunch by bunch feedback system for the Spring-8 
     storage ring.
     """
-    
-    def __init__(self, ring, plane, tune, turn_delay, tap_number, gain, phase, 
-                 bpm_error=None, max_kick=None):
-        
+    def __init__(self,
+                 ring,
+                 plane,
+                 tune,
+                 turn_delay,
+                 tap_number,
+                 gain,
+                 phase,
+                 bpm_error=None,
+                 max_kick=None):
+
         self.ring = ring
         self.tune = tune
         self.turn_delay = turn_delay
@@ -127,7 +136,7 @@ class FIRDamper(Element):
         self.bpm_error = bpm_error
         self.max_kick = max_kick
         self.plane = plane
-        
+
         if self.plane == "x":
             self.action = "xp"
             self.damp_idx = 0
@@ -140,15 +149,15 @@ class FIRDamper(Element):
             self.action = "delta"
             self.damp_idx = 2
             self.mean_idx = 4
-            
+
         self.beam_no_mpi = False
-        
-        self.pos = np.zeros((self.tap_number,1))
-        self.kick = np.zeros((self.turn_delay+1,1))
-        self.coef = self.get_fir(self.tap_number, self.tune, self.phase, 
-                                   self.turn_delay, self.gain)
-        
-    def get_fir(self, tap_number, tune, phase, turn_delay, gain):        
+
+        self.pos = np.zeros((self.tap_number, 1))
+        self.kick = np.zeros((self.turn_delay + 1, 1))
+        self.coef = self.get_fir(self.tap_number, self.tune, self.phase,
+                                 self.turn_delay, self.gain)
+
+    def get_fir(self, tap_number, tune, phase, turn_delay, gain):
         """
         Compute the FIR coefficients.
         
@@ -159,29 +168,32 @@ class FIRDamper(Element):
         FIR_coef : array
             Array containing the FIR coefficients.
         """
-        it = np.zeros((tap_number,))
-        CC = np.zeros((5, tap_number,))
-        zeta = (phase*2*np.pi)/360
+        it = np.zeros((tap_number, ))
+        CC = np.zeros((
+            5,
+            tap_number,
+        ))
+        zeta = (phase * 2 * np.pi) / 360
         for k in range(tap_number):
             it[k] = (-k - turn_delay)
-        
-        phi = 2*np.pi*tune
-        cs = np.cos(phi*it)
-        sn = np.sin(phi*it)
-        
+
+        phi = 2 * np.pi * tune
+        cs = np.cos(phi * it)
+        sn = np.sin(phi * it)
+
         CC[0][:] = 1
         CC[1][:] = cs
         CC[2][:] = sn
-        CC[3][:] = it*sn
-        CC[4][:] = it*cs
-        
+        CC[3][:] = it * sn
+        CC[4][:] = it * cs
+
         TCC = np.transpose(CC)
         W = np.linalg.inv(CC.dot(TCC))
         D = W.dot(CC)
-        
-        FIR_coef = gain*(D[1][:]*np.cos(zeta) + D[2][:]*np.sin(zeta))
+
+        FIR_coef = gain * (D[1][:] * np.cos(zeta) + D[2][:] * np.sin(zeta))
         return FIR_coef
-    
+
     def plot_fir(self):
         """
         Plot the gain and the phase of the FIR filter.
@@ -193,26 +205,26 @@ class FIRDamper(Element):
             
         """
         tune = np.arange(0, 1, 0.0001)
-            
+
         H_FIR = 0
         for k in range(len(self.coef)):
-            H_FIR += self.coef[k]*np.exp(-1j*2*np.pi*(k)*tune)
-        latency = np.exp(-1j*2*np.pi*tune*self.turn_delay)
+            H_FIR += self.coef[k] * np.exp(-1j * 2 * np.pi * (k) * tune)
+        latency = np.exp(-1j * 2 * np.pi * tune * self.turn_delay)
         H_tot = H_FIR * latency
-        
+
         gain = np.abs(H_tot)
-        phase = np.angle(H_tot, deg = True)
-        
-        fig, [ax1, ax2] = plt.subplots(2,1)
+        phase = np.angle(H_tot, deg=True)
+
+        fig, [ax1, ax2] = plt.subplots(2, 1)
         ax1.plot(tune, gain)
         ax1.set_ylabel("Gain")
-        
+
         ax2.plot(tune, phase)
         ax2.set_xlabel("Tune")
         ax2.set_ylabel("Phase in degree")
-        
+
         return fig
-         
+
     def track(self, beam_or_bunch):
         """
         Tracking method.
@@ -236,7 +248,7 @@ class FIRDamper(Element):
                     self.track_sb(bunch, i)
         else:
             TypeError("beam_or_bunch must be a Beam or Bunch")
-            
+
     def init_beam_no_mpi(self, beam):
         """
         Change array sizes if Beam is used without mpi.
@@ -249,9 +261,9 @@ class FIRDamper(Element):
         """
         n_bunch = len(beam)
         self.pos = np.zeros((self.tap_number, n_bunch))
-        self.kick = np.zeros((self.turn_delay+1, n_bunch))
+        self.kick = np.zeros((self.turn_delay + 1, n_bunch))
         self.beam_no_mpi = True
-        
+
     def track_sb(self, bunch, bunch_number=0):
         """
         Core of the tracking method.
@@ -268,19 +280,19 @@ class FIRDamper(Element):
         self.pos[0, bunch_number] = bunch.mean[self.mean_idx]
         if self.bpm_error is not None:
             self.pos[0, bunch_number] += np.random.normal(0, self.bpm_error)
-            
+
         kick = 0
         for k in range(self.tap_number):
-            kick += self.coef[k]*self.pos[k, bunch_number]
-            
+            kick += self.coef[k] * self.pos[k, bunch_number]
+
         if self.max_kick is not None:
             if kick > self.max_kick:
                 kick = self.max_kick
-            elif kick < -1*self.max_kick:
-                kick = -1*self.max_kick
-            
+            elif kick < -1 * self.max_kick:
+                kick = -1 * self.max_kick
+
         self.kick[-1, bunch_number] = kick
         bunch[self.action] += self.kick[0, bunch_number]
-        
+
         self.pos[:, bunch_number] = np.roll(self.pos[:, bunch_number], 1)
-        self.kick[:, bunch_number] = np.roll(self.kick[:, bunch_number], -1)
\ No newline at end of file
+        self.kick[:, bunch_number] = np.roll(self.kick[:, bunch_number], -1)
diff --git a/mbtrack2/tracking/monitors/__init__.py b/mbtrack2/tracking/monitors/__init__.py
index b4688b7378e57fc37f7133981abd6259df1b2eda..4f1633f9a1569b8a14ea4297612e832ae34bf073 100644
--- a/mbtrack2/tracking/monitors/__init__.py
+++ b/mbtrack2/tracking/monitors/__init__.py
@@ -1,22 +1,26 @@
 # -*- coding: utf-8 -*-
-from mbtrack2.tracking.monitors.monitors import (Monitor, BunchMonitor, 
-                                                 PhaseSpaceMonitor,
-                                                 BeamMonitor,
-                                                 ProfileMonitor,
-                                                 WakePotentialMonitor,
-                                                 CavityMonitor,
-                                                 BunchSpectrumMonitor,
-                                                 BeamSpectrumMonitor)
-from mbtrack2.tracking.monitors.plotting import (plot_bunchdata, 
-                                                 plot_phasespacedata,
-                                                 plot_profiledata,
-                                                 plot_beamdata,
-                                                 plot_wakedata,
-                                                 plot_cavitydata,
-                                                 streak_beamdata,
-                                                 plot_bunchspectrum,
-                                                 streak_bunchspectrum,
-                                                 plot_beamspectrum,
-                                                 streak_beamspectrum)
-
-from mbtrack2.tracking.monitors.tools import (merge_files, copy_files)
\ No newline at end of file
+from mbtrack2.tracking.monitors.monitors import (
+    BeamMonitor,
+    BeamSpectrumMonitor,
+    BunchMonitor,
+    BunchSpectrumMonitor,
+    CavityMonitor,
+    Monitor,
+    PhaseSpaceMonitor,
+    ProfileMonitor,
+    WakePotentialMonitor,
+)
+from mbtrack2.tracking.monitors.plotting import (
+    plot_beamdata,
+    plot_beamspectrum,
+    plot_bunchdata,
+    plot_bunchspectrum,
+    plot_cavitydata,
+    plot_phasespacedata,
+    plot_profiledata,
+    plot_wakedata,
+    streak_beamdata,
+    streak_beamspectrum,
+    streak_bunchspectrum,
+)
+from mbtrack2.tracking.monitors.tools import copy_files, merge_files
diff --git a/mbtrack2/tracking/monitors/monitors.py b/mbtrack2/tracking/monitors/monitors.py
index f8641fe9985aabeaa56348c3e276e0db91bfd8b5..69c62ae91f832551be3a31c8dacdebea076151be 100644
--- a/mbtrack2/tracking/monitors/monitors.py
+++ b/mbtrack2/tracking/monitors/monitors.py
@@ -4,15 +4,18 @@ This module defines the different monitor class which are used to save data
 during tracking.
 """
 
-import numpy as np
-import h5py as hp
 import random
-from mbtrack2.tracking.element import Element
-from mbtrack2.tracking.particles import Bunch, Beam
-from mbtrack2.tracking.rf import CavityResonator
-from scipy.interpolate import interp1d
 from abc import ABCMeta
+
+import h5py as hp
+import numpy as np
 from scipy.fft import rfft, rfftfreq
+from scipy.interpolate import interp1d
+
+from mbtrack2.tracking.element import Element
+from mbtrack2.tracking.particles import Beam, Bunch
+from mbtrack2.tracking.rf import CavityResonator
+
 
 class Monitor(Element, metaclass=ABCMeta):
     """
@@ -48,10 +51,10 @@ class Monitor(Element, metaclass=ABCMeta):
     track_bunch_data(object_to_save)
         Track method to use when saving bunch data.
     """
-    
+
     _file_name_storage = []
     _file_storage = []
-    
+
     @property
     def file_name(self):
         """Common file where all monitors, Monitor subclass elements, write the
@@ -61,7 +64,7 @@ class Monitor(Element, metaclass=ABCMeta):
         except IndexError:
             print("The HDF5 file name for monitors is not set.")
             raise ValueError
-            
+
     @property
     def file(self):
         """Name of the HDF5 file where the data is stored."""
@@ -70,9 +73,16 @@ class Monitor(Element, metaclass=ABCMeta):
         except IndexError:
             print("The HDF5 file to store data is not set.")
             raise ValueError
-            
-    def monitor_init(self, group_name, save_every, buffer_size, total_size,
-                     dict_buffer, dict_file, file_name=None, mpi_mode=False,
+
+    def monitor_init(self,
+                     group_name,
+                     save_every,
+                     buffer_size,
+                     total_size,
+                     dict_buffer,
+                     dict_file,
+                     file_name=None,
+                     mpi_mode=False,
                      dict_dtype=None):
         """
         Method called to initialize Monitor subclass. 
@@ -113,7 +123,7 @@ class Monitor(Element, metaclass=ABCMeta):
             the dtype to use to save the values.
             If None, float is used for all attributes.
         """
-        
+
         # setup and open common file for all monitors
         if file_name is not None:
             if len(self._file_name_storage) == 0:
@@ -121,16 +131,20 @@ class Monitor(Element, metaclass=ABCMeta):
                 if len(self._file_storage) == 0:
                     if mpi_mode == True:
                         from mpi4py import MPI
-                        f = hp.File(self.file_name, "a", libver='earliest', 
-                             driver='mpio', comm=MPI.COMM_WORLD)
+                        f = hp.File(self.file_name,
+                                    "a",
+                                    libver='earliest',
+                                    driver='mpio',
+                                    comm=MPI.COMM_WORLD)
                     else:
                         f = hp.File(self.file_name, "a", libver='earliest')
                     self._file_storage.append(f)
                 else:
                     raise ValueError("File is already open.")
             else:
-                raise ValueError("File name for monitors is already attributed.")
-        
+                raise ValueError(
+                    "File name for monitors is already attributed.")
+
         self.group_name = group_name
         self.save_every = int(save_every)
         self.total_size = int(total_size)
@@ -140,48 +154,52 @@ class Monitor(Element, metaclass=ABCMeta):
         self.buffer_count = 0
         self.write_count = 0
         self.track_count = 0
-        
+
         # setup attribute buffers from values given in dict_buffer
         for key, value in dict_buffer.items():
             if dict_dtype == None:
-                self.__setattr__(key,np.zeros(value))
+                self.__setattr__(key, np.zeros(value))
             else:
-                self.__setattr__(key,np.zeros(value, dtype=dict_dtype[key]))
-        self.time = np.zeros((self.buffer_size,), dtype=int)
+                self.__setattr__(key, np.zeros(value, dtype=dict_dtype[key]))
+        self.time = np.zeros((self.buffer_size, ), dtype=int)
 
-        # create HDF5 groups and datasets to save data from group_name and 
+        # create HDF5 groups and datasets to save data from group_name and
         # dict_file
         self.g = self.file.require_group(self.group_name)
-        self.g.require_dataset("time", (self.total_size,), dtype=int)
+        self.g.require_dataset("time", (self.total_size, ), dtype=int)
         for key, value in dict_file.items():
             if dict_dtype == None:
                 self.g.require_dataset(key, value, dtype=float)
             else:
                 self.g.require_dataset(key, value, dtype=dict_dtype[key])
-        
+
         # create a dictionary which handle slices
         slice_dict = {}
         for key, value in dict_file.items():
             slice_dict[key] = []
-            for i in range(len(value)-1):
+            for i in range(len(value) - 1):
                 slice_dict[key].append(slice(None))
         self.slice_dict = slice_dict
-        
+
     def write(self):
         """Write data from buffer to the HDF5 file."""
-        
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
+
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
         for key, value in self.dict_buffer.items():
             slice_list = list(self.slice_dict[key])
-            slice_list.append(slice(self.write_count*self.buffer_size,
-                                    (self.write_count+1)*self.buffer_size))
+            slice_list.append(
+                slice(self.write_count * self.buffer_size,
+                      (self.write_count + 1) * self.buffer_size))
             slice_tuple = tuple(slice_list)
-            self.file[self.group_name][key][slice_tuple] = self.__getattribute__(key)
-        
+            self.file[
+                self.group_name][key][slice_tuple] = self.__getattribute__(key)
+
         self.file.flush()
         self.write_count += 1
-        
+
     def to_buffer(self, object_to_save):
         """
         Save data to buffer.
@@ -196,13 +214,14 @@ class Monitor(Element, metaclass=ABCMeta):
             slice_list = list(self.slice_dict[key])
             slice_list.append(self.buffer_count)
             slice_tuple = tuple(slice_list)
-            self.__getattribute__(key)[slice_tuple] = object_to_save.__getattribute__(key)
+            self.__getattribute__(
+                key)[slice_tuple] = object_to_save.__getattribute__(key)
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-            
+
     def close(self):
         """
         Close the HDF5 file shared by all Monitor subclass, must be called 
@@ -215,7 +234,7 @@ class Monitor(Element, metaclass=ABCMeta):
             Monitor._file_storage = []
         except ValueError:
             pass
-        
+
     def track_bunch_data(self, object_to_save, check_empty=False):
         """
         Track method to use when saving bunch data.
@@ -228,7 +247,7 @@ class Monitor(Element, metaclass=ABCMeta):
         """
         save = True
         if self.track_count % self.save_every == 0:
-            
+
             if isinstance(object_to_save, Beam):
                 if (object_to_save.mpi_switch == True):
                     if object_to_save.mpi.bunch_num == self.bunch_number:
@@ -239,15 +258,17 @@ class Monitor(Element, metaclass=ABCMeta):
                     bunch = object_to_save[self.bunch_number]
             elif isinstance(object_to_save, Bunch):
                 bunch = object_to_save
-            else:                            
-                raise TypeError("object_to_save should be a Beam or Bunch object.")
-            
+            else:
+                raise TypeError(
+                    "object_to_save should be a Beam or Bunch object.")
+
             if save:
                 if (check_empty == False) or (bunch.is_empty == False):
                     self.to_buffer(bunch)
-                
+
         self.track_count += 1
-            
+
+
 class BunchMonitor(Monitor):
     """
     Monitor a single bunch and save attributes 
@@ -281,24 +302,36 @@ class BunchMonitor(Monitor):
     track(object_to_save)
         Save data
     """
-    
-    def __init__(self, bunch_number, save_every, buffer_size, total_size, 
-                 file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 bunch_number,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 file_name=None,
+                 mpi_mode=False):
+
         self.bunch_number = bunch_number
         group_name = "BunchData_" + str(self.bunch_number)
-        dict_buffer = {"mean":(6, buffer_size), "std":(6, buffer_size),
-                     "emit":(3, buffer_size), "current":(buffer_size,),
-                     "cs_invariant":(3, buffer_size)}
-        dict_file = {"mean":(6, total_size), "std":(6, total_size),
-                     "emit":(3, total_size), "current":(total_size,),
-                     "cs_invariant":(3, total_size)}
+        dict_buffer = {
+            "mean": (6, buffer_size),
+            "std": (6, buffer_size),
+            "emit": (3, buffer_size),
+            "current": (buffer_size, ),
+            "cs_invariant": (3, buffer_size)
+        }
+        dict_file = {
+            "mean": (6, total_size),
+            "std": (6, total_size),
+            "emit": (3, total_size),
+            "current": (total_size, ),
+            "cs_invariant": (3, total_size)
+        }
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-                    
+
     def track(self, object_to_save):
         """
         Save data
@@ -306,9 +339,10 @@ class BunchMonitor(Monitor):
         Parameters
         ----------
         object_to_save : Bunch or Beam object
-        """        
+        """
         self.track_bunch_data(object_to_save)
-        
+
+
 class PhaseSpaceMonitor(Monitor):
     """
     Monitor a single bunch and save the full phase space.
@@ -344,23 +378,32 @@ class PhaseSpaceMonitor(Monitor):
     track(object_to_save)
         Save data
     """
-    
-    def __init__(self, bunch_number, mp_number, save_every, buffer_size, 
-                 total_size, file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 bunch_number,
+                 mp_number,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 file_name=None,
+                 mpi_mode=False):
+
         self.bunch_number = bunch_number
         self.mp_number = int(mp_number)
         group_name = "PhaseSpaceData_" + str(self.bunch_number)
-        dict_buffer = {"particles":(self.mp_number, 6, buffer_size), 
-                       "alive":(self.mp_number, buffer_size)}
-        dict_file = {"particles":(self.mp_number, 6, total_size),
-                     "alive":(self.mp_number, total_size)}
+        dict_buffer = {
+            "particles": (self.mp_number, 6, buffer_size),
+            "alive": (self.mp_number, buffer_size)
+        }
+        dict_file = {
+            "particles": (self.mp_number, 6, total_size),
+            "alive": (self.mp_number, total_size)
+        }
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-                    
+
     def track(self, object_to_save):
         """
         Save data
@@ -368,9 +411,9 @@ class PhaseSpaceMonitor(Monitor):
         Parameters
         ----------
         object_to_save : Bunch or Beam object
-        """        
+        """
         self.track_bunch_data(object_to_save)
-        
+
     def to_buffer(self, bunch):
         """
         Save data to buffer.
@@ -380,7 +423,7 @@ class PhaseSpaceMonitor(Monitor):
         bunch : Bunch object
         """
         self.time[self.buffer_count] = self.track_count
-        
+
         if len(bunch.alive) != self.mp_number:
             index = np.arange(len(bunch.alive))
             samples_meta = random.sample(list(index), self.mp_number)
@@ -390,16 +433,16 @@ class PhaseSpaceMonitor(Monitor):
 
         self.alive[:, self.buffer_count] = bunch.alive[samples]
         for i, dim in enumerate(bunch):
-            self.particles[:, i, self.buffer_count] = bunch.particles[dim][samples]
-        
+            self.particles[:, i,
+                           self.buffer_count] = bunch.particles[dim][samples]
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-        
 
-            
+
 class BeamMonitor(Monitor):
     """
     Monitor the full beam and save each bunch attributes (mean, std, emit and 
@@ -433,25 +476,33 @@ class BeamMonitor(Monitor):
     track(beam)
         Save data    
     """
-    
-    def __init__(self, h, save_every, buffer_size, total_size, file_name=None, 
+    def __init__(self,
+                 h,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 file_name=None,
                  mpi_mode=False):
-        
+
         group_name = "Beam"
-        dict_buffer = {"mean" : (6, h, buffer_size), 
-                       "std" : (6, h, buffer_size),
-                       "emit" : (3, h, buffer_size),
-                       "current" : (h, buffer_size),
-                       "cs_invariant" : (3, h, buffer_size)}
-        dict_file = {"mean" : (6, h, total_size), 
-                       "std" : (6, h, total_size),
-                       "emit" : (3, h, total_size),
-                       "current" : (h, total_size),
-                       "cs_invariant" : (3, h, total_size)}
-        
+        dict_buffer = {
+            "mean": (6, h, buffer_size),
+            "std": (6, h, buffer_size),
+            "emit": (3, h, buffer_size),
+            "current": (h, buffer_size),
+            "cs_invariant": (3, h, buffer_size)
+        }
+        dict_file = {
+            "mean": (6, h, total_size),
+            "std": (6, h, total_size),
+            "emit": (3, h, total_size),
+            "current": (h, total_size),
+            "cs_invariant": (3, h, total_size)
+        }
+
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-                    
+
     def track(self, beam):
         """
         Save data
@@ -459,15 +510,15 @@ class BeamMonitor(Monitor):
         Parameters
         ----------
         beam : Beam object
-        """     
+        """
         if self.track_count % self.save_every == 0:
             if (beam.mpi_switch == True):
                 self.to_buffer(beam[beam.mpi.bunch_num], beam.mpi.bunch_num)
             else:
                 self.to_buffer_no_mpi(beam)
-                    
+
         self.track_count += 1
-        
+
     def to_buffer(self, bunch, bunch_num):
         """
         Save data to buffer, if mpi is being used.
@@ -477,20 +528,20 @@ class BeamMonitor(Monitor):
         bunch : Bunch object
         bunch_num : int
         """
-        
+
         self.time[self.buffer_count] = self.track_count
         self.mean[:, bunch_num, self.buffer_count] = bunch.mean
         self.std[:, bunch_num, self.buffer_count] = bunch.std
         self.emit[:, bunch_num, self.buffer_count] = bunch.emit
         self.current[bunch_num, self.buffer_count] = bunch.current
         self.cs_invariant[:, bunch_num, self.buffer_count] = bunch.cs_invariant
-        
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write(bunch_num)
             self.buffer_count = 0
-            
+
     def to_buffer_no_mpi(self, beam):
         """
         Save data to buffer, if mpi is not being used.
@@ -499,16 +550,16 @@ class BeamMonitor(Monitor):
         ----------
         beam : Beam object
         """
-              
+
         self.time[self.buffer_count] = self.track_count
         self.mean[:, :, self.buffer_count] = beam.bunch_mean
         self.std[:, :, self.buffer_count] = beam.bunch_std
         self.emit[:, :, self.buffer_count] = beam.bunch_emit
         self.current[:, self.buffer_count] = beam.bunch_current
         self.cs_invariant[:, :, self.buffer_count] = beam.bunch_cs
-        
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write_no_mpi()
             self.buffer_count = 0
@@ -521,64 +572,80 @@ class BeamMonitor(Monitor):
         ----------
         bunch_num : int
         """
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
-    
-        self.file[self.group_name]["mean"][:, bunch_num, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.mean[:, bunch_num, :]
-                 
-        self.file[self.group_name]["std"][:, bunch_num, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.std[:, bunch_num, :]
-
-        self.file[self.group_name]["emit"][:, bunch_num, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.emit[:, bunch_num, :]
-
-        self.file[self.group_name]["current"][bunch_num, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.current[bunch_num, :]
-        
-        self.file[self.group_name]["cs_invariant"][:, bunch_num, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.cs_invariant[:, bunch_num, :]
-                 
-        self.file.flush() 
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
+
+        self.file[self.group_name][
+            "mean"][:, bunch_num, self.write_count *
+                    self.buffer_size:(self.write_count + 1) *
+                    self.buffer_size] = self.mean[:, bunch_num, :]
+
+        self.file[
+            self.group_name]["std"][:, bunch_num, self.write_count *
+                                    self.buffer_size:(self.write_count + 1) *
+                                    self.buffer_size] = self.std[:,
+                                                                 bunch_num, :]
+
+        self.file[self.group_name][
+            "emit"][:, bunch_num, self.write_count *
+                    self.buffer_size:(self.write_count + 1) *
+                    self.buffer_size] = self.emit[:, bunch_num, :]
+
+        self.file[self.group_name]["current"][
+            bunch_num,
+            self.write_count * self.buffer_size:(self.write_count + 1) *
+            self.buffer_size] = self.current[bunch_num, :]
+
+        self.file[self.group_name][
+            "cs_invariant"][:, bunch_num, self.write_count *
+                            self.buffer_size:(self.write_count + 1) *
+                            self.buffer_size] = self.cs_invariant[:,
+                                                                  bunch_num, :]
+
+        self.file.flush()
         self.write_count += 1
-        
+
     def write_no_mpi(self):
         """
         Write data from buffer to the HDF5 file, if mpi is not being used.
         """
-        
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
-    
-        self.file[self.group_name]["mean"][:, :, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.mean
-                 
-        self.file[self.group_name]["std"][:, :, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.std
-
-        self.file[self.group_name]["emit"][:, :, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.emit
-
-        self.file[self.group_name]["current"][:, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.current
-        
-        self.file[self.group_name]["cs_invariant"][:, :, 
-                 self.write_count*self.buffer_size:(self.write_count+1) * 
-                 self.buffer_size] = self.cs_invariant
-                 
-        self.file.flush() 
+
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
+
+        self.file[self.group_name]["mean"][:, :, self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.mean
+
+        self.file[self.group_name]["std"][:, :, self.write_count *
+                                          self.buffer_size:(self.write_count +
+                                                            1) *
+                                          self.buffer_size] = self.std
+
+        self.file[self.group_name]["emit"][:, :, self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.emit
+
+        self.file[self.group_name]["current"][:, self.write_count *
+                                              self.buffer_size:
+                                              (self.write_count + 1) *
+                                              self.buffer_size] = self.current
+
+        self.file[self.group_name][
+            "cs_invariant"][:, :, self.write_count *
+                            self.buffer_size:(self.write_count + 1) *
+                            self.buffer_size] = self.cs_invariant
+
+        self.file.flush()
         self.write_count += 1
 
-        
+
 class ProfileMonitor(Monitor):
     """
     Monitor a single bunch and save bunch profiles.
@@ -615,37 +682,45 @@ class ProfileMonitor(Monitor):
     track(object_to_save)
         Save data.
     """
-    
-    def __init__(self, bunch_number, save_every, buffer_size, total_size, 
-                 dimensions="tau", n_bin=75, file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 bunch_number,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 dimensions="tau",
+                 n_bin=75,
+                 file_name=None,
+                 mpi_mode=False):
+
         self.bunch_number = bunch_number
         group_name = "ProfileData_" + str(self.bunch_number)
-        
+
         if isinstance(dimensions, str):
             self.dimensions = [dimensions]
         else:
             self.dimensions = dimensions
-            
+
         if isinstance(n_bin, int):
-            self.n_bin = np.ones((len(self.dimensions),), dtype=int)*n_bin
+            self.n_bin = np.ones((len(self.dimensions), ), dtype=int) * n_bin
         else:
             self.n_bin = n_bin
-        
+
         dict_buffer = {}
         dict_file = {}
         for index, dim in enumerate(self.dimensions):
-            dict_buffer.update({dim : (self.n_bin[index] - 1, buffer_size)})
-            dict_buffer.update({dim + "_bin" : (self.n_bin[index] - 1, buffer_size)})
-            dict_file.update({dim : (self.n_bin[index] - 1, total_size)})
-            dict_file.update({dim + "_bin" : (self.n_bin[index] - 1, total_size)})
+            dict_buffer.update({dim: (self.n_bin[index] - 1, buffer_size)})
+            dict_buffer.update(
+                {dim + "_bin": (self.n_bin[index] - 1, buffer_size)})
+            dict_file.update({dim: (self.n_bin[index] - 1, total_size)})
+            dict_file.update(
+                {dim + "_bin": (self.n_bin[index] - 1, total_size)})
 
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-        
+
     def to_buffer(self, bunch):
         """
         Save data to buffer.
@@ -657,32 +732,38 @@ class ProfileMonitor(Monitor):
 
         self.time[self.buffer_count] = self.track_count
         for index, dim in enumerate(self.dimensions):
-            bins, sorted_index, profile, center = bunch.binning(dim, self.n_bin[index])
+            bins, sorted_index, profile, center = bunch.binning(
+                dim, self.n_bin[index])
             self.__getattribute__(dim + "_bin")[:, self.buffer_count] = center
             self.__getattribute__(dim)[:, self.buffer_count] = profile
-        
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-            
+
     def write(self):
         """Write data from buffer to the HDF5 file."""
-        
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
+
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
 
         for dim in self.dimensions:
-            self.file[self.group_name][dim][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__(dim)
-            self.file[self.group_name][dim + "_bin"][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__(dim + "_bin")
-            
+            self.file[self.group_name][
+                dim][:, self.write_count *
+                     self.buffer_size:(self.write_count + 1) *
+                     self.buffer_size] = self.__getattribute__(dim)
+            self.file[self.group_name][
+                dim + "_bin"][:, self.write_count *
+                              self.buffer_size:(self.write_count + 1) *
+                              self.buffer_size] = self.__getattribute__(dim +
+                                                                        "_bin")
+
         self.write_count += 1
-                    
+
     def track(self, object_to_save):
         """
         Save data.
@@ -690,9 +771,10 @@ class ProfileMonitor(Monitor):
         Parameters
         ----------
         object_to_save : Bunch or Beam object
-        """        
+        """
         self.track_bunch_data(object_to_save, check_empty=True)
-        
+
+
 class WakePotentialMonitor(Monitor):
     """
     Monitor the wake potential from a single bunch and save attributes (tau, 
@@ -731,39 +813,46 @@ class WakePotentialMonitor(Monitor):
     track(object_to_save, wake_potential_to_save)
         Save data.
     """
-    
-    def __init__(self, bunch_number, wake_types, n_bin, save_every, 
-                 buffer_size, total_size, file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 bunch_number,
+                 wake_types,
+                 n_bin,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 file_name=None,
+                 mpi_mode=False):
+
         self.bunch_number = bunch_number
         group_name = "WakePotentialData_" + str(self.bunch_number)
-        
+
         if isinstance(wake_types, str):
             self.wake_types = [wake_types]
         else:
             self.wake_types = wake_types
-            
-        self.n_bin = n_bin*2
-        
+
+        self.n_bin = n_bin * 2
+
         dict_buffer = {}
         dict_file = {}
         for index, dim in enumerate(self.wake_types):
-            dict_buffer.update({"tau_" + dim : (self.n_bin, buffer_size)})
-            dict_file.update({"tau_" + dim : (self.n_bin, total_size)})
-            dict_buffer.update({"profile_" + dim : (self.n_bin, buffer_size)})
-            dict_file.update({"profile_" + dim : (self.n_bin, total_size)})
-            dict_buffer.update({dim : (self.n_bin, buffer_size)})
-            dict_file.update({dim : (self.n_bin, total_size)})
+            dict_buffer.update({"tau_" + dim: (self.n_bin, buffer_size)})
+            dict_file.update({"tau_" + dim: (self.n_bin, total_size)})
+            dict_buffer.update({"profile_" + dim: (self.n_bin, buffer_size)})
+            dict_file.update({"profile_" + dim: (self.n_bin, total_size)})
+            dict_buffer.update({dim: (self.n_bin, buffer_size)})
+            dict_file.update({dim: (self.n_bin, total_size)})
             if dim == "Wxdip" or dim == "Wydip":
-                dict_buffer.update({"dipole_" + dim : (self.n_bin, buffer_size)})
-                dict_file.update({"dipole_" + dim : (self.n_bin, total_size)})
+                dict_buffer.update(
+                    {"dipole_" + dim: (self.n_bin, buffer_size)})
+                dict_file.update({"dipole_" + dim: (self.n_bin, total_size)})
 
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-        
+
     def to_buffer(self, wp):
         """
         Save data to buffer.
@@ -782,52 +871,67 @@ class WakePotentialMonitor(Monitor):
                 dipole0 = wp.__getattribute__("dipole_x")
             if dim == "Wydip":
                 dipole0 = wp.__getattribute__("dipole_y")
-            
+
             tau = np.linspace(tau0[0], tau0[-1], self.n_bin)
-            f = interp1d(tau0, WP0, fill_value = 0, bounds_error = False)
+            f = interp1d(tau0, WP0, fill_value=0, bounds_error=False)
             WP = f(tau)
-            g = interp1d(tau0, profile0, fill_value = 0, bounds_error = False)
+            g = interp1d(tau0, profile0, fill_value=0, bounds_error=False)
             profile = g(tau)
             if dim == "Wxdip" or dim == "Wydip":
-                h = interp1d(tau0, dipole0, fill_value = 0, bounds_error = False)
+                h = interp1d(tau0, dipole0, fill_value=0, bounds_error=False)
                 dipole = h(tau)
-            
-            self.__getattribute__("tau_" + dim)[:, self.buffer_count] = tau + wp.tau_mean
-            self.__getattribute__("profile_" + dim)[:, self.buffer_count] = profile
+
+            self.__getattribute__("tau_" +
+                                  dim)[:,
+                                       self.buffer_count] = tau + wp.tau_mean
+            self.__getattribute__("profile_" +
+                                  dim)[:, self.buffer_count] = profile
             self.__getattribute__(dim)[:, self.buffer_count] = WP
             if dim == "Wxdip" or dim == "Wydip":
-                self.__getattribute__("dipole_" + dim)[:, self.buffer_count] = dipole
-            
+                self.__getattribute__("dipole_" +
+                                      dim)[:, self.buffer_count] = dipole
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-            
+
     def write(self):
         """Write data from buffer to the HDF5 file."""
-        
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
-        
+
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
+
         for dim in self.wake_types:
-            self.file[self.group_name]["tau_" + dim][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__("tau_" + dim)
-            self.file[self.group_name]["profile_" + dim][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__("profile_" + dim)
-            self.file[self.group_name][dim][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__(dim)
+            self.file[self.group_name][
+                "tau_" +
+                dim][:, self.write_count *
+                     self.buffer_size:(self.write_count + 1) *
+                     self.buffer_size] = self.__getattribute__("tau_" + dim)
+            self.file[self.group_name][
+                "profile_" +
+                dim][:, self.write_count *
+                     self.buffer_size:(self.write_count + 1) *
+                     self.buffer_size] = self.__getattribute__("profile_" +
+                                                               dim)
+            self.file[self.group_name][
+                dim][:, self.write_count *
+                     self.buffer_size:(self.write_count + 1) *
+                     self.buffer_size] = self.__getattribute__(dim)
             if dim == "Wxdip" or dim == "Wydip":
-                self.file[self.group_name]["dipole_" + dim][:, 
-                    self.write_count * self.buffer_size:(self.write_count+1) * 
-                    self.buffer_size] = self.__getattribute__("dipole_" + dim)
-            
+                self.file[self.group_name][
+                    "dipole_" +
+                    dim][:, self.write_count *
+                         self.buffer_size:(self.write_count + 1) *
+                         self.buffer_size] = self.__getattribute__("dipole_" +
+                                                                   dim)
+
         self.file.flush()
         self.write_count += 1
-                    
+
     def track(self, object_to_save, wake_potential_to_save):
         """
         Save data.
@@ -849,13 +953,14 @@ class WakePotentialMonitor(Monitor):
                                           "with MPI mode.")
         elif isinstance(object_to_save, Bunch):
             save = True
-        else:                            
+        else:
             raise TypeError("object_to_save should be a Beam or Bunch object.")
-            
+
         if save and (self.track_count % self.save_every == 0):
             self.to_buffer(wake_potential_to_save)
         self.track_count += 1
-    
+
+
 class BunchSpectrumMonitor(Monitor):
     """
     Monitor the coherent and incoherent bunch spectrums. 
@@ -918,77 +1023,91 @@ class BunchSpectrumMonitor(Monitor):
         Save spectrum data.
     
     """
-    
-    def __init__(self, ring, bunch_number, mp_number, sample_size, save_every, 
-                 buffer_size, total_size, dim="all", n_fft=None, 
-                 file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 ring,
+                 bunch_number,
+                 mp_number,
+                 sample_size,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 dim="all",
+                 n_fft=None,
+                 file_name=None,
+                 mpi_mode=False):
+
         if n_fft is None:
             self.n_fft = int(save_every)
         else:
             self.n_fft = int(n_fft)
-            
+
         self.sample_size = int(sample_size)
-        self.store_dict = {"x":0,"y":1,"tau":2}
+        self.store_dict = {"x": 0, "y": 1, "tau": 2}
 
         if dim == "all":
-            self.track_dict = {"x":0,"y":1,"tau":2}
+            self.track_dict = {"x": 0, "y": 1, "tau": 2}
             self.mean_index = [True, False, True, False, True, False]
         elif dim == "tau":
-            self.track_dict = {"tau":0}
+            self.track_dict = {"tau": 0}
             self.mean_index = [False, False, False, False, True, False]
         elif dim == "x":
-            self.track_dict = {"x":0}
+            self.track_dict = {"x": 0}
             self.mean_index = [True, False, False, False, False, False]
         elif dim == "y":
-            self.track_dict = {"y":0}
+            self.track_dict = {"y": 0}
             self.mean_index = [False, False, True, False, False, False]
         elif dim == "xy" or dim == "yx":
-            self.track_dict = {"x":0,"y":1}
+            self.track_dict = {"x": 0, "y": 1}
             self.mean_index = [True, False, True, False, False, False]
         elif dim == "xtau" or dim == "taux":
-            self.track_dict = {"x":0,"tau":1}
+            self.track_dict = {"x": 0, "tau": 1}
             self.mean_index = [True, False, False, False, True, False]
         elif dim == "ytau" or dim == "tauy":
-            self.track_dict = {"y":0,"tau":1}
+            self.track_dict = {"y": 0, "tau": 1}
             self.mean_index = [False, False, True, False, True, False]
         else:
             raise ValueError("dim is not correct.")
-        
+
         self.size_list = len(self.track_dict)
-        
+
         self.ring = ring
         self.bunch_number = bunch_number
         group_name = "BunchSpectrum_" + str(self.bunch_number)
-        
-        dict_buffer = {"incoherent":(3, self.n_fft//2+1, buffer_size),
-                       "coherent":(3, self.n_fft//2+1, buffer_size),
-                       "mean_incoherent":(3,buffer_size),
-                       "std_incoherent":(3,buffer_size)}
-        dict_file = {"incoherent":(3, self.n_fft//2+1, total_size),
-                        "coherent":(3, self.n_fft//2+1, total_size),
-                        "mean_incoherent":(3,total_size),
-                        "std_incoherent":(3,total_size)}
-        
+
+        dict_buffer = {
+            "incoherent": (3, self.n_fft // 2 + 1, buffer_size),
+            "coherent": (3, self.n_fft // 2 + 1, buffer_size),
+            "mean_incoherent": (3, buffer_size),
+            "std_incoherent": (3, buffer_size)
+        }
+        dict_file = {
+            "incoherent": (3, self.n_fft // 2 + 1, total_size),
+            "coherent": (3, self.n_fft // 2 + 1, total_size),
+            "mean_incoherent": (3, total_size),
+            "std_incoherent": (3, total_size)
+        }
+
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-        
+
         self.save_count = 0
-        
-        self.positions = np.zeros((self.size_list, self.sample_size, self.save_every+1))
-        self.mean = np.zeros((self.size_list, self.save_every+1))
-        
+
+        self.positions = np.zeros(
+            (self.size_list, self.sample_size, self.save_every + 1))
+        self.mean = np.zeros((self.size_list, self.save_every + 1))
+
         index = np.arange(0, int(mp_number))
-        self.index_sample = sorted(random.sample(list(index), self.sample_size))
-                
-        self.incoherent = np.zeros((3, self.n_fft//2+1, self.buffer_size))
-        self.coherent = np.zeros((3, self.n_fft//2+1, self.buffer_size))
-        
-        self.file[self.group_name].create_dataset(
-            "freq", data=self.frequency_samples)
+        self.index_sample = sorted(random.sample(list(index),
+                                                 self.sample_size))
+
+        self.incoherent = np.zeros((3, self.n_fft // 2 + 1, self.buffer_size))
+        self.coherent = np.zeros((3, self.n_fft // 2 + 1, self.buffer_size))
+
+        self.file[self.group_name].create_dataset("freq",
+                                                  data=self.frequency_samples)
 
     @property
     def fft_resolution(self):
@@ -997,8 +1116,8 @@ class BunchSpectrumMonitor(Monitor):
         
         It is defined as the sampling frequency over the number of samples.
         """
-        return self.ring.f0/self.n_fft
-    
+        return self.ring.f0 / self.n_fft
+
     @property
     def signal_resolution(self):
         """
@@ -1006,15 +1125,15 @@ class BunchSpectrumMonitor(Monitor):
         
         It is defined as the inverse of the signal length.
         """
-        return 1/(self.ring.T0*self.save_every)
-    
+        return 1 / (self.ring.T0 * self.save_every)
+
     @property
     def frequency_samples(self):
         """
         Return the fft frequency samples in [Hz].
         """
-        return rfftfreq(self.n_fft, self.ring.T0)    
-        
+        return rfftfreq(self.n_fft, self.ring.T0)
+
     def track(self, object_to_save):
         """
         Save spectrum data.
@@ -1037,69 +1156,84 @@ class BunchSpectrumMonitor(Monitor):
             bunch = object_to_save
         else:
             raise TypeError("object_to_save should be a Beam or Bunch object.")
-        
+
         if save:
             try:
                 for key, value in self.track_dict.items():
-                    self.positions[value, :, self.save_count] = bunch[key][self.index_sample]
+                    self.positions[value, :, self.save_count] = bunch[key][
+                        self.index_sample]
             except IndexError:
                 self.positions[value, :, self.save_count] = np.nan
-            
+
             self.mean[:, self.save_count] = bunch.mean[self.mean_index]
-            
+
             self.save_count += 1
-            
+
             if self.track_count > 0 and self.track_count % self.save_every == 0:
                 self.to_buffer(bunch)
                 self.save_count = 0
-    
+
             self.track_count += 1
-        
+
     def to_buffer(self, bunch):
         """
         A method to hold saved data before writing it to the output file.
 
         """
-        
+
         self.time[self.buffer_count] = self.track_count
-        
+
         for key, value in self.track_dict.items():
-            incoherent, mean_incoherent, std_incoherent = self.get_incoherent_spectrum(self.positions[value,:,:])
-            self.incoherent[self.store_dict[key],:,self.buffer_count] = incoherent
-            self.mean_incoherent[self.store_dict[key],self.buffer_count] = mean_incoherent
-            self.std_incoherent[self.store_dict[key],self.buffer_count] = std_incoherent
-            self.coherent[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.mean[value])
-        
+            incoherent, mean_incoherent, std_incoherent = self.get_incoherent_spectrum(
+                self.positions[value, :, :])
+            self.incoherent[self.store_dict[key], :,
+                            self.buffer_count] = incoherent
+            self.mean_incoherent[self.store_dict[key],
+                                 self.buffer_count] = mean_incoherent
+            self.std_incoherent[self.store_dict[key],
+                                self.buffer_count] = std_incoherent
+            self.coherent[self.store_dict[key], :,
+                          self.buffer_count] = self.get_coherent_spectrum(
+                              self.mean[value])
+
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-            
+
     def write(self):
         """
         Write data from buffer to output file.
 
         """
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
-
-        self.file[self.group_name]["incoherent"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.incoherent
-        self.file[self.group_name]["mean_incoherent"][:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.mean_incoherent
-        self.file[self.group_name]["std_incoherent"][:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.std_incoherent
-        self.file[self.group_name]["coherent"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.coherent
-            
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
+
+        self.file[
+            self.group_name]["incoherent"][:, :, self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.incoherent
+        self.file[self.group_name][
+            "mean_incoherent"][:, self.write_count *
+                               self.buffer_size:(self.write_count + 1) *
+                               self.buffer_size] = self.mean_incoherent
+        self.file[self.group_name][
+            "std_incoherent"][:, self.write_count *
+                              self.buffer_size:(self.write_count + 1) *
+                              self.buffer_size] = self.std_incoherent
+        self.file[
+            self.group_name]["coherent"][:, :, self.write_count *
+                                         self.buffer_size:(self.write_count +
+                                                           1) *
+                                         self.buffer_size] = self.coherent
+
         self.file.flush()
         self.write_count += 1
-    
+
     def get_incoherent_spectrum(self, positions):
         """
         Compute the incoherent spectrum i.e. the average of the absolute value 
@@ -1119,14 +1253,14 @@ class BunchSpectrumMonitor(Monitor):
         """
         fourier = rfft(positions, n=self.n_fft)
         fourier_abs = np.abs(fourier)
-        max_array = np.argmax(fourier_abs,axis=1)
+        max_array = np.argmax(fourier_abs, axis=1)
         freq_array = self.frequency_samples[max_array]
         mean_incoherent = np.mean(freq_array)
         std_incoherent = np.std(freq_array)
         incoherent = np.mean(fourier_abs, axis=0)
-        
+
         return incoherent, mean_incoherent, std_incoherent
-    
+
     def get_coherent_spectrum(self, mean):
         """
         Compute the coherent spectrum i.e. the absolute value of the FT of the
@@ -1139,9 +1273,10 @@ class BunchSpectrumMonitor(Monitor):
 
         """
         coherent = np.abs(rfft(mean, n=self.n_fft))
-        
+
         return coherent
-    
+
+
 class BeamSpectrumMonitor(Monitor):
     """
     Monitor coherent beam spectrum. 
@@ -1197,63 +1332,69 @@ class BeamSpectrumMonitor(Monitor):
         Save spectrum data.
     
     """
-    
-    def __init__(self, ring, save_every, buffer_size, total_size, dim="all", 
-                 n_fft=None, file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 ring,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 dim="all",
+                 n_fft=None,
+                 file_name=None,
+                 mpi_mode=False):
+
         if n_fft is None:
             self.n_fft = int(save_every)
         else:
             self.n_fft = int(n_fft)
-            
-        self.store_dict = {"x":0,"y":1,"tau":2}
+
+        self.store_dict = {"x": 0, "y": 1, "tau": 2}
 
         if dim == "all":
-            self.track_dict = {"x":0,"y":1,"tau":2}
+            self.track_dict = {"x": 0, "y": 1, "tau": 2}
             self.mean_index = [True, False, True, False, True, False]
         elif dim == "tau":
-            self.track_dict = {"tau":0}
+            self.track_dict = {"tau": 0}
             self.mean_index = [False, False, False, False, True, False]
         elif dim == "x":
-            self.track_dict = {"x":0}
+            self.track_dict = {"x": 0}
             self.mean_index = [True, False, False, False, False, False]
         elif dim == "y":
-            self.track_dict = {"y":0}
+            self.track_dict = {"y": 0}
             self.mean_index = [False, False, True, False, False, False]
         elif dim == "xy" or dim == "yx":
-            self.track_dict = {"x":0,"y":1}
+            self.track_dict = {"x": 0, "y": 1}
             self.mean_index = [True, False, True, False, False, False]
         elif dim == "xtau" or dim == "taux":
-            self.track_dict = {"x":0,"tau":1}
+            self.track_dict = {"x": 0, "tau": 1}
             self.mean_index = [True, False, False, False, True, False]
         elif dim == "ytau" or dim == "tauy":
-            self.track_dict = {"y":0,"tau":1}
+            self.track_dict = {"y": 0, "tau": 1}
             self.mean_index = [False, False, True, False, True, False]
         else:
             raise ValueError("dim is not correct.")
-        
+
         self.size_list = len(self.track_dict)
-        
+
         self.ring = ring
         group_name = "BeamSpectrum"
-        
-        dict_buffer = {"coherent":(3, self.n_fft//2+1, buffer_size)}
-        dict_file = {"coherent":(3, self.n_fft//2+1, total_size)}
-        
+
+        dict_buffer = {"coherent": (3, self.n_fft // 2 + 1, buffer_size)}
+        dict_file = {"coherent": (3, self.n_fft // 2 + 1, total_size)}
+
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-        
+
         self.save_count = 0
-        
+
         self.mean = np.zeros((self.size_list, ring.h, self.save_every))
-        self.coherent = np.zeros((3, self.n_fft//2+1, self.buffer_size))
-        
-        self.file[self.group_name].create_dataset(
-            "freq", data=self.frequency_samples)
-        
+        self.coherent = np.zeros((3, self.n_fft // 2 + 1, self.buffer_size))
+
+        self.file[self.group_name].create_dataset("freq",
+                                                  data=self.frequency_samples)
+
     @property
     def fft_resolution(self):
         """
@@ -1261,8 +1402,8 @@ class BeamSpectrumMonitor(Monitor):
         
         It is defined as the sampling frequency over the number of samples.
         """
-        return self.ring.f1/self.n_fft
-    
+        return self.ring.f1 / self.n_fft
+
     @property
     def signal_resolution(self):
         """
@@ -1270,15 +1411,15 @@ class BeamSpectrumMonitor(Monitor):
         
         It is defined as the inverse of the signal length.
         """
-        return 1/(self.ring.T0*self.save_every)
-    
+        return 1 / (self.ring.T0 * self.save_every)
+
     @property
     def frequency_samples(self):
         """
         Return the fft frequency samples in [Hz].
         """
         return rfftfreq(self.n_fft, self.ring.T1)
-        
+
     def track(self, beam):
         """
         Save mean data.
@@ -1291,16 +1432,18 @@ class BeamSpectrumMonitor(Monitor):
         if (beam.mpi_switch == True):
             bunch_num = beam.mpi.bunch_num
             bunch = beam[bunch_num]
-            self.mean[:, bunch_num, self.save_count] = bunch.mean[self.mean_index]
+            self.mean[:, bunch_num,
+                      self.save_count] = bunch.mean[self.mean_index]
         else:
-            self.mean[:, :, self.save_count] = beam.bunch_mean[self.mean_index,:]
-            
+            self.mean[:, :,
+                      self.save_count] = beam.bunch_mean[self.mean_index, :]
+
         self.save_count += 1
-            
+
         if self.save_count == self.save_every:
             self.to_buffer(beam)
             self.save_count = 0
-        
+
         self.track_count += 1
 
     def to_buffer(self, beam):
@@ -1308,9 +1451,9 @@ class BeamSpectrumMonitor(Monitor):
         A method to hold saved data before writing it to the output file.
 
         """
-        
+
         self.time[self.buffer_count] = self.track_count
-        
+
         for key, value in self.track_dict.items():
             if (beam.mpi_switch == True):
                 data_core = self.mean[value, beam.mpi.bunch_num, :]
@@ -1318,28 +1461,33 @@ class BeamSpectrumMonitor(Monitor):
                 data = np.reshape(full_data, (-1), 'F')
             else:
                 data = np.reshape(self.mean[value, :, :], (-1), 'F')
-            self.coherent[self.store_dict[key],:,self.buffer_count] = self.get_beam_spectrum(data)        
+            self.coherent[self.store_dict[key], :,
+                          self.buffer_count] = self.get_beam_spectrum(data)
         self.buffer_count += 1
-        
+
         if self.buffer_count == self.buffer_size:
             self.write()
             self.buffer_count = 0
-            
+
     def write(self):
         """
         Write data from buffer to output file.
 
         """
-        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
-                    self.write_count+1)*self.buffer_size] = self.time
+        self.file[self.group_name]["time"][self.write_count *
+                                           self.buffer_size:(self.write_count +
+                                                             1) *
+                                           self.buffer_size] = self.time
+
+        self.file[
+            self.group_name]["coherent"][:, :, self.write_count *
+                                         self.buffer_size:(self.write_count +
+                                                           1) *
+                                         self.buffer_size] = self.coherent
 
-        self.file[self.group_name]["coherent"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.coherent
-            
         self.file.flush()
         self.write_count += 1
-    
+
     def get_beam_spectrum(self, mean):
         """
         Compute the beam coherent spectrum i.e. the absolute value of the FT 
@@ -1352,9 +1500,10 @@ class BeamSpectrumMonitor(Monitor):
 
         """
         coherent = np.abs(rfft(mean, n=self.n_fft))
-        
+
         return coherent
-        
+
+
 class CavityMonitor(Monitor):
     """
     Monitor a CavityResonator object and save attributes.
@@ -1388,64 +1537,99 @@ class CavityMonitor(Monitor):
     track(beam, cavity)
         Save data
     """
-    
-    def __init__(self, cavity_name, ring, save_every, buffer_size, total_size, 
-                 file_name=None, mpi_mode=False):
-        
+    def __init__(self,
+                 cavity_name,
+                 ring,
+                 save_every,
+                 buffer_size,
+                 total_size,
+                 file_name=None,
+                 mpi_mode=False):
+
         self.cavity_name = cavity_name
         self.ring = ring
-        
+
         group_name = cavity_name
-        dict_buffer = {"cavity_phasor_record":(ring.h, buffer_size,),
-                       "beam_phasor_record":(ring.h, buffer_size,),
-                       "generator_phasor_record":(ring.h, buffer_size,),
-                       "ig_phasor_record":(ring.h, buffer_size,),
-                       "detune":(buffer_size,),
-                       "psi":(buffer_size,),
-                       "Vg":(buffer_size,),
-                       "theta_g":(buffer_size,),
-                       "Pg":(buffer_size,),
-                       "Rs":(buffer_size,),
-                       "Q":(buffer_size,),
-                       "QL":(buffer_size,),
-                       "Vc":(buffer_size,),
-                       "theta":(buffer_size,),}
-        dict_file = {"cavity_phasor_record":(ring.h, total_size,),
-                     "beam_phasor_record":(ring.h, total_size,),
-                     "generator_phasor_record":(ring.h, total_size,),
-                     "ig_phasor_record":(ring.h, total_size,),
-                     "detune":(total_size,),
-                     "psi":(total_size,),
-                     "Vg":(total_size,),
-                     "theta_g":(total_size,),
-                     "Pg":(total_size,),
-                     "Rs":(total_size,),
-                     "Q":(total_size,),
-                     "QL":(total_size,),
-                     "Vc":(total_size,),
-                     "theta":(total_size,)}
-        dict_dtype = {"cavity_phasor_record":complex,
-                      "beam_phasor_record":complex,
-                      "generator_phasor_record":complex,
-                      "ig_phasor_record":complex,
-                      "detune":float,
-                      "psi":float,
-                      "Vg":float,
-                      "theta_g":float,
-                      "Pg":float,
-                      "Rs":float,
-                      "Q":float,
-                      "QL":float,
-                      "Vc":float,
-                      "theta":float}
-        
+        dict_buffer = {
+            "cavity_phasor_record": (
+                ring.h,
+                buffer_size,
+            ),
+            "beam_phasor_record": (
+                ring.h,
+                buffer_size,
+            ),
+            "generator_phasor_record": (
+                ring.h,
+                buffer_size,
+            ),
+            "ig_phasor_record": (
+                ring.h,
+                buffer_size,
+            ),
+            "detune": (buffer_size, ),
+            "psi": (buffer_size, ),
+            "Vg": (buffer_size, ),
+            "theta_g": (buffer_size, ),
+            "Pg": (buffer_size, ),
+            "Rs": (buffer_size, ),
+            "Q": (buffer_size, ),
+            "QL": (buffer_size, ),
+            "Vc": (buffer_size, ),
+            "theta": (buffer_size, ),
+        }
+        dict_file = {
+            "cavity_phasor_record": (
+                ring.h,
+                total_size,
+            ),
+            "beam_phasor_record": (
+                ring.h,
+                total_size,
+            ),
+            "generator_phasor_record": (
+                ring.h,
+                total_size,
+            ),
+            "ig_phasor_record": (
+                ring.h,
+                total_size,
+            ),
+            "detune": (total_size, ),
+            "psi": (total_size, ),
+            "Vg": (total_size, ),
+            "theta_g": (total_size, ),
+            "Pg": (total_size, ),
+            "Rs": (total_size, ),
+            "Q": (total_size, ),
+            "QL": (total_size, ),
+            "Vc": (total_size, ),
+            "theta": (total_size, )
+        }
+        dict_dtype = {
+            "cavity_phasor_record": complex,
+            "beam_phasor_record": complex,
+            "generator_phasor_record": complex,
+            "ig_phasor_record": complex,
+            "detune": float,
+            "psi": float,
+            "Vg": float,
+            "theta_g": float,
+            "Pg": float,
+            "Rs": float,
+            "Q": float,
+            "QL": float,
+            "Vc": float,
+            "theta": float
+        }
+
         self.monitor_init(group_name, save_every, buffer_size, total_size,
-                          dict_buffer, dict_file, file_name, mpi_mode, 
+                          dict_buffer, dict_file, file_name, mpi_mode,
                           dict_dtype)
-        
+
         self.dict_buffer = dict_buffer
         self.dict_file = dict_file
-                    
+
     def track(self, beam, cavity):
         """
         Save data
@@ -1454,7 +1638,7 @@ class CavityMonitor(Monitor):
         ----------
         beam : Beam object
         cavity : CavityResonator object
-        """        
+        """
         if self.track_count % self.save_every == 0:
             if isinstance(cavity, CavityResonator):
                 if beam.mpi_switch == False:
@@ -1463,6 +1647,6 @@ class CavityMonitor(Monitor):
                     self.to_buffer(cavity)
                 else:
                     pass
-            else:                            
+            else:
                 raise TypeError("cavity should be a CavityResonator object.")
-        self.track_count += 1       
\ No newline at end of file
+        self.track_count += 1
diff --git a/mbtrack2/tracking/monitors/plotting.py b/mbtrack2/tracking/monitors/plotting.py
index 30773006a323561d397a8d6834f7f8639174c0e1..e407151aa17b7450205afcf80f6909c56ff5f970 100644
--- a/mbtrack2/tracking/monitors/plotting.py
+++ b/mbtrack2/tracking/monitors/plotting.py
@@ -4,16 +4,23 @@ Module for plotting the data recorded by the monitor module during the
 tracking.
 """
 
-import numpy as np
-import matplotlib.pyplot as plt
+import random
+
+import h5py as hp
 import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
 import seaborn as sns
-import h5py as hp
-import random
 from scipy.stats import gmean
 
-def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean", 
-                  x_var="time", turn=None, legend=None):
+
+def plot_beamdata(filenames,
+                  dataset="mean",
+                  dimension="tau",
+                  stat_var="mean",
+                  x_var="time",
+                  turn=None,
+                  legend=None):
     """
     Plot 2D data recorded by BeamMonitor.
 
@@ -48,60 +55,69 @@ def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean",
         Figure object with the plot on it.
 
     """
-    
+
     if isinstance(filenames, str):
         filenames = [filenames]
-    
+
     fig, ax = plt.subplots()
-    
+
     for filename in filenames:
         file = hp.File(filename, "r")
         time = np.array(file["Beam"]["time"])
         data = np.array(file["Beam"][dataset])
-            
+
         if x_var == "time":
             x = time
             x_label = "Number of turns"
-            bunch_index = (file["Beam"]["current"][:,0] != 0).nonzero()[0]
+            bunch_index = (file["Beam"]["current"][:, 0] != 0).nonzero()[0]
             if dataset == "current":
-                y = np.nansum(data[bunch_index,:],0)*1e3
+                y = np.nansum(data[bunch_index, :], 0) * 1e3
                 y_label = "Total current (mA)"
             elif dataset == "emit":
-                dimension_dict = {"x":0, "y":1, "s":2}
+                dimension_dict = {"x": 0, "y": 1, "s": 2}
                 axis = dimension_dict[dimension]
-                label = ["$\\epsilon_{x}$ (m.rad)",
-                         "$\\epsilon_{y}$ (m.rad)",
-                         "$\\epsilon_{s}$ (s)"]
+                label = [
+                    "$\\epsilon_{x}$ (m.rad)", "$\\epsilon_{y}$ (m.rad)",
+                    "$\\epsilon_{s}$ (s)"
+                ]
                 if stat_var == "mean":
-                    y = np.nanmean(data[axis,bunch_index,:],0)
+                    y = np.nanmean(data[axis, bunch_index, :], 0)
                 elif stat_var == "std":
-                    y = np.nanstd(data[axis,bunch_index,:],0)
+                    y = np.nanstd(data[axis, bunch_index, :], 0)
                 y_label = stat_var + " " + label[axis]
             elif dataset == "cs_invariant":
-                dimension_dict = {"x":0, "y":1, "s":2}
+                dimension_dict = {"x": 0, "y": 1, "s": 2}
                 axis = dimension_dict[dimension]
                 label = ['$J_x$ (m)', '$J_y$ (m)', '$J_s$ (s)']
                 if stat_var == "mean":
-                    y = np.nanmean(data[axis,bunch_index,:],0)
+                    y = np.nanmean(data[axis, bunch_index, :], 0)
                 elif stat_var == "std":
-                    y = np.nanstd(data[axis,bunch_index,:],0)
+                    y = np.nanstd(data[axis, bunch_index, :], 0)
                 y_label = stat_var + " " + label[axis]
             elif dataset == "mean" or dataset == "std":
-                dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
-                                  "delta":5}
+                dimension_dict = {
+                    "x": 0,
+                    "xp": 1,
+                    "y": 2,
+                    "yp": 3,
+                    "tau": 4,
+                    "delta": 5
+                }
                 axis = dimension_dict[dimension]
                 scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-                label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
-                         "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
-                if stat_var == "mean":   
-                    y = np.nanmean(data[axis,bunch_index,:],0)*scale[axis]
-                elif stat_var == "std":      
-                    y = np.nanstd(data[axis,bunch_index,:],0)*scale[axis]
-                label_sup = {"mean":"mean of ", "std":"std of "}
+                label = [
+                    "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
+                    "$\\tau$ (ps)", "$\\delta$"
+                ]
+                if stat_var == "mean":
+                    y = np.nanmean(data[axis, bunch_index, :], 0) * scale[axis]
+                elif stat_var == "std":
+                    y = np.nanstd(data[axis, bunch_index, :], 0) * scale[axis]
+                label_sup = {"mean": "mean of ", "std": "std of "}
                 y_label = label_sup[stat_var] + dataset + " " + label[axis]
-                
+
         elif x_var == "index":
-            h = len(file["Beam"]["mean"][0,:,0])
+            h = len(file["Beam"]["mean"][0, :, 0])
             x = np.arange(h)
             x_label = "Bunch index"
             if turn is None:
@@ -110,48 +126,58 @@ def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean",
                 idx = np.where(time == int(turn))[0]
                 if (idx.size == 0):
                     raise ValueError("Turn is not valid.")
-            
+
             if dataset == "current":
-                y = data[:,idx]*1e3
+                y = data[:, idx] * 1e3
                 y_label = "Bunch current (mA)"
             elif dataset == "emit":
-                dimension_dict = {"x":0, "y":1, "s":2}
+                dimension_dict = {"x": 0, "y": 1, "s": 2}
                 axis = dimension_dict[dimension]
-                label = ["$\\epsilon_{x}$ (m.rad)",
-                         "$\\epsilon_{y}$ (m.rad)",
-                         "$\\epsilon_{s}$ (s)"]
-                y = data[axis,:,idx]
+                label = [
+                    "$\\epsilon_{x}$ (m.rad)", "$\\epsilon_{y}$ (m.rad)",
+                    "$\\epsilon_{s}$ (s)"
+                ]
+                y = data[axis, :, idx]
                 y_label = label[axis]
             elif dataset == "cs_invariant":
-                dimension_dict = {"x":0, "y":1, "s":2}
+                dimension_dict = {"x": 0, "y": 1, "s": 2}
                 axis = dimension_dict[dimension]
                 label = ['$J_x$ (m)', '$J_y$ (m)', '$J_s$ (s)']
-                y = data[axis,:,idx]
+                y = data[axis, :, idx]
                 y_label = label[axis]
             elif dataset == "mean" or dataset == "std":
-                dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
-                                  "delta":5}
+                dimension_dict = {
+                    "x": 0,
+                    "xp": 1,
+                    "y": 2,
+                    "yp": 3,
+                    "tau": 4,
+                    "delta": 5
+                }
                 axis = dimension_dict[dimension]
                 scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-                label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
-                         "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
-                y = data[axis,:,idx]*scale[axis]
+                label = [
+                    "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
+                    "$\\tau$ (ps)", "$\\delta$"
+                ]
+                y = data[axis, :, idx] * scale[axis]
                 y_label = dataset + " " + label[axis]
         else:
             raise ValueError("x_var should be time or index")
-            
+
         y = np.squeeze(y)
-        
+
         ax.plot(x, y)
         ax.set_xlabel(x_label)
         ax.set_ylabel(y_label)
         if legend is not None:
             plt.legend(legend)
-            
+
         file.close()
-        
+
     return fig
-            
+
+
 def streak_beamdata(filename, dataset="mean", dimension="tau", cm_lim=None):
     """
     Plot 3D data recorded by BeamMonitor.
@@ -178,72 +204,89 @@ def streak_beamdata(filename, dataset="mean", dimension="tau", cm_lim=None):
         Figure object with the plot on it.
 
     """
-    
+
     file = hp.File(filename, "r")
     data = file["Beam"]
     time = np.array(data["time"])
-        
-    h = len(data["mean"][0,:,0])
+
+    h = len(data["mean"][0, :, 0])
     x = np.arange(h)
     x_label = "Bunch index"
     y = time
     y_label = "Number of turns"
     if dataset == "current":
-        z = (np.array(data["current"])*1e3).T
+        z = (np.array(data["current"]) * 1e3).T
         z_label = "Bunch current (mA)"
         title = z_label
     elif dataset == "emit":
-        dimension_dict = {"x":0, "y":1, "s":2}
+        dimension_dict = {"x": 0, "y": 1, "s": 2}
         axis = dimension_dict[dimension]
-        label = ["$\\epsilon_{x}$ (m.rad)",
-                 "$\\epsilon_{y}$ (m.rad)",
-                 "$\\epsilon_{s}$ (s)"]
-        z = np.array(data["emit"][axis,:,:]).T
+        label = [
+            "$\\epsilon_{x}$ (m.rad)", "$\\epsilon_{y}$ (m.rad)",
+            "$\\epsilon_{s}$ (s)"
+        ]
+        z = np.array(data["emit"][axis, :, :]).T
         z_label = label[axis]
         title = z_label
     elif dataset == "cs_invariant":
-        dimension_dict = {"x":0, "y":1, "s":2}
+        dimension_dict = {"x": 0, "y": 1, "s": 2}
         axis = dimension_dict[dimension]
         label = ['$J_x$ (m)', '$J_y$ (m)', '$J_s$ (s)']
-        z = np.array(data["cs_invariant"][axis,:,:]).T
+        z = np.array(data["cs_invariant"][axis, :, :]).T
         z_label = label[axis]
         title = z_label
     else:
-        dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
-                              "delta":5}
+        dimension_dict = {
+            "x": 0,
+            "xp": 1,
+            "y": 2,
+            "yp": 3,
+            "tau": 4,
+            "delta": 5
+        }
         axis = dimension_dict[dimension]
         scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-        label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
-                     "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
-        z = np.array(data[dataset][axis,:,:]).T*scale[axis]
+        label = [
+            "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
+            "$\\tau$ (ps)", "$\\delta$"
+        ]
+        z = np.array(data[dataset][axis, :, :]).T * scale[axis]
         z_label = label[axis]
         if dataset == "mean":
             title = label[axis] + " CM"
         elif dataset == "std":
             title = label[axis] + " RMS"
-            
+
     fig, ax = plt.subplots()
     ax.set_xlabel(x_label)
     ax.set_ylabel(y_label)
     ax.set_title(title)
-    
+
     if dataset == "mean":
-        cmap = mpl.cm.coolwarm # diverging
+        cmap = mpl.cm.coolwarm  # diverging
     else:
-        cmap = mpl.cm.inferno # sequential
-    
-    c = ax.imshow(z, cmap=cmap, origin='lower' , aspect='auto',
-            extent=[x.min(), x.max(), y.min(), y.max()])
+        cmap = mpl.cm.inferno  # sequential
+
+    c = ax.imshow(z,
+                  cmap=cmap,
+                  origin='lower',
+                  aspect='auto',
+                  extent=[x.min(), x.max(), y.min(),
+                          y.max()])
     if cm_lim is not None:
-        c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
+        c.set_clim(vmin=cm_lim[0], vmax=cm_lim[1])
     cbar = fig.colorbar(c, ax=ax)
     cbar.set_label(z_label)
-    
+
     file.close()
-        
+
     return fig
-              
-def plot_bunchdata(filenames, bunch_number, dataset, dimension="x", 
+
+
+def plot_bunchdata(filenames,
+                   bunch_number,
+                   dataset,
+                   dimension="x",
                    legend=None):
     """
     Plot data recorded by BunchMonitor.
@@ -273,54 +316,65 @@ def plot_bunchdata(filenames, bunch_number, dataset, dimension="x",
         Figure object with the plot on it.
 
     """
-    
+
     if isinstance(filenames, str):
         filenames = [filenames]
-        
+
     if isinstance(bunch_number, int):
         ll = []
         for i in range(len(filenames)):
             ll.append(bunch_number)
         bunch_number = ll
-        
+
     fig, ax = plt.subplots()
-    
+
     for i, filename in enumerate(filenames):
         file = hp.File(filename, "r")
-        group = "BunchData_{0}".format(bunch_number[i])  # Data group of the HDF5 file
-        
+        group = "BunchData_{0}".format(
+            bunch_number[i])  # Data group of the HDF5 file
+
         if dataset == "current":
-            y_var = file[group][dataset][:]*1e3
+            y_var = file[group][dataset][:] * 1e3
             label = "current (mA)"
-            
+
         elif dataset == "emit":
-            dimension_dict = {"x":0, "y":1, "s":2}
-                             
+            dimension_dict = {"x": 0, "y": 1, "s": 2}
+
             y_var = file[group][dataset][dimension_dict[dimension]]
-            
+
             if dimension == "x": label = "hor. emittance (m.rad)"
             elif dimension == "y": label = "ver. emittance (m.rad)"
             elif dimension == "s": label = "long. emittance (s)"
-            
-            
-        elif dataset == "mean" or dataset == "std":                        
-            dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
-            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]        
+
+        elif dataset == "mean" or dataset == "std":
+            dimension_dict = {
+                "x": 0,
+                "xp": 1,
+                "y": 2,
+                "yp": 3,
+                "tau": 4,
+                "delta": 5
+            }
+            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
             axis_index = dimension_dict[dimension]
-            
-            y_var = file[group][dataset][axis_index]*scale[axis_index]
+
+            y_var = file[group][dataset][axis_index] * scale[axis_index]
             if dataset == "mean":
-                label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
-                              "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
+                label_list = [
+                    "x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
+                    "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"
+                ]
             else:
-                label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
-                              "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
-                              "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
-            
+                label_list = [
+                    "$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
+                    "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)",
+                    "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"
+                ]
+
             label = label_list[axis_index]
-            
+
         elif dataset == "cs_invariant":
-            dimension_dict = {"x":0, "y":1, "s":2}
+            dimension_dict = {"x": 0, "y": 1, "s": 2}
             axis_index = dimension_dict[dimension]
             y_var = file[group][dataset][axis_index]
             label_list = ['$J_x$ (m)', '$J_y$ (m)', '$J_s$ (s)']
@@ -328,19 +382,26 @@ def plot_bunchdata(filenames, bunch_number, dataset, dimension="x",
 
         x_axis = file[group]["time"][:]
         xlabel = "Number of turns"
-        
+
         ax.plot(x_axis, y_var)
         ax.set_xlabel(xlabel)
         ax.set_ylabel(label)
         if legend is not None:
             plt.legend(legend)
-            
+
         file.close()
-        
+
     return fig
-            
-def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
-                        only_alive=True, plot_size=1, plot_kind='kde'):
+
+
+def plot_phasespacedata(filename,
+                        bunch_number,
+                        x_var,
+                        y_var,
+                        turn,
+                        only_alive=True,
+                        plot_size=1,
+                        plot_kind='kde'):
     """
     Plot data recorded by PhaseSpaceMonitor.
 
@@ -371,56 +432,66 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
     fig : Figure
         Figure object with the plot on it.
     """
-    
+
     file = hp.File(filename, "r")
-    
+
     group = "PhaseSpaceData_{0}".format(bunch_number)
     dataset = "particles"
 
-    option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-    scale = [1e3,1e3,1e3,1e3,1e12,1]
-    label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
-             "$\\delta$"]
-    
+    option_dict = {"x": 0, "xp": 1, "y": 2, "yp": 3, "tau": 4, "delta": 5}
+    scale = [1e3, 1e3, 1e3, 1e3, 1e12, 1]
+    label = [
+        "x (mm)", "x' (mrad)", "y (mm)", "y' (mrad)", "$\\tau$ (ps)",
+        "$\\delta$"
+    ]
+
     # find the index of "turn" in time array
-    turn_index = np.where(file[group]["time"][:]==turn) 
-    
+    turn_index = np.where(file[group]["time"][:] == turn)
+
     if len(turn_index[0]) == 0:
-        raise ValueError("Turn {0} is not found. Enter turn from {1}.".
-                         format(turn, file[group]["time"][:]))     
-    
+        raise ValueError("Turn {0} is not found. Enter turn from {1}.".format(
+            turn, file[group]["time"][:]))
+
     path = file[group][dataset]
-    mp_number = path[:,0,0].size
+    mp_number = path[:, 0, 0].size
 
     if only_alive is True:
         data = np.array(file[group]["alive"])
-        index = np.where(data[:,turn_index])[0]
+        index = np.where(data[:, turn_index])[0]
     else:
         index = np.arange(mp_number)
-        
+
     if plot_size == 1:
         samples = index
     elif plot_size < 1:
-        samples_meta = random.sample(list(index), int(plot_size*mp_number))
+        samples_meta = random.sample(list(index), int(plot_size * mp_number))
         samples = sorted(samples_meta)
     else:
         raise ValueError("plot_size must be in range [0,1].")
-            
+
     # format : sns.jointplot(x_axis, yaxis, kind)
-    x_axis = path[samples,option_dict[x_var],turn_index[0][0]]
-    y_axis = path[samples,option_dict[y_var],turn_index[0][0]]    
-        
-    fig = sns.jointplot(x_axis*scale[option_dict[x_var]], 
-                        y_axis*scale[option_dict[y_var]], kind=plot_kind)
-   
+    x_axis = path[samples, option_dict[x_var], turn_index[0][0]]
+    y_axis = path[samples, option_dict[y_var], turn_index[0][0]]
+
+    fig = sns.jointplot(x_axis * scale[option_dict[x_var]],
+                        y_axis * scale[option_dict[y_var]],
+                        kind=plot_kind)
+
     plt.xlabel(label[option_dict[x_var]])
     plt.ylabel(label[option_dict[y_var]])
-            
+
     file.close()
     return fig
 
-def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
-                     stop=None, step=None, profile_plot=True, streak_plot=True):
+
+def plot_profiledata(filename,
+                     bunch_number,
+                     dimension="tau",
+                     start=0,
+                     stop=None,
+                     step=None,
+                     profile_plot=True,
+                     streak_plot=True):
     """
     Plot data recorded by ProfileMonitor
 
@@ -452,74 +523,80 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
         Figure object with the plot on it.
 
     """
-    
+
     file = hp.File(filename, "r")
     path = file['ProfileData_{0}'.format(bunch_number)]
     l_bound = np.array(path["{0}_bin".format(dimension)])
     data = np.array(path[dimension])
     time = np.array(path["time"])
-    
+
     if stop is None:
         stop = time[-1]
     elif stop not in time:
-        raise ValueError("stop not found. Choose from {0}"
-                         .format(time[:]))
- 
+        raise ValueError("stop not found. Choose from {0}".format(time[:]))
+
     if start not in time:
-        raise ValueError("start not found. Choose from {0}"
-                         .format(time[:]))
-    
+        raise ValueError("start not found. Choose from {0}".format(time[:]))
+
     save_every = time[1] - time[0]
-    
+
     if step is None:
         step = save_every
-    
+
     if step % save_every != 0:
         raise ValueError("step must be divisible by the recording step "
                          "which is {0}.".format(save_every))
-    
-    dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+
+    dimension_dict = {"x": 0, "xp": 1, "y": 2, "yp": 3, "tau": 4, "delta": 5}
     scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-    label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-             "$\\tau$ (ps)", "$\\delta$"]
-    
-    num = int((stop - start)/step)
-    n_bin = len(data[:,0])
-    
+    label = [
+        "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", "$\\tau$ (ps)",
+        "$\\delta$"
+    ]
+
+    num = int((stop-start) / step)
+    n_bin = len(data[:, 0])
+
     start_index = np.where(time[:] == start)[0][0]
 
-    x_var = np.zeros((num+1,n_bin))
-    turn_index_array = np.zeros((num+1,), dtype=int)
-    for i in range(num+1):
-        turn_index = int(start_index + i * step / save_every)
+    x_var = np.zeros((num + 1, n_bin))
+    turn_index_array = np.zeros((num + 1, ), dtype=int)
+    for i in range(num + 1):
+        turn_index = int(start_index + i*step/save_every)
         turn_index_array[i] = turn_index
         # construct an array of bin mids
-        x_var[i,:] = l_bound[:,turn_index]
-        
+        x_var[i, :] = l_bound[:, turn_index]
+
     if profile_plot is True:
         fig, ax = plt.subplots()
-        for i in range(num+1):
-            ax.plot(x_var[i]*scale[dimension_dict[dimension]],
-                    data[:,turn_index_array[i]], 
+        for i in range(num + 1):
+            ax.plot(x_var[i] * scale[dimension_dict[dimension]],
+                    data[:, turn_index_array[i]],
                     label="turn {0}".format(time[turn_index_array[i]]))
         ax.set_xlabel(label[dimension_dict[dimension]])
-        ax.set_ylabel("number of macro-particles")         
+        ax.set_ylabel("number of macro-particles")
         ax.legend()
-            
+
     if streak_plot is True:
-        turn = np.reshape(time[turn_index_array], (num+1,1))
-        y_var = np.ones((num+1,n_bin)) * turn
-        z_var = np.transpose(data[:,turn_index_array])
+        turn = np.reshape(time[turn_index_array], (num + 1, 1))
+        y_var = np.ones((num + 1, n_bin)) * turn
+        z_var = np.transpose(data[:, turn_index_array])
         fig2, ax2 = plt.subplots()
-        cmap = mpl.cm.inferno # sequential
-        c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
-                       extent=[x_var.min()*scale[dimension_dict[dimension]],
-                               x_var.max()*scale[dimension_dict[dimension]],
-                               y_var.min(),y_var.max()])
+        cmap = mpl.cm.inferno  # sequential
+        c = ax2.imshow(z_var,
+                       cmap=cmap,
+                       origin='lower',
+                       aspect='auto',
+                       extent=[
+                           x_var.min() * scale[dimension_dict[dimension]],
+                           x_var.max() * scale[dimension_dict[dimension]],
+                           y_var.min(),
+                           y_var.max()
+                       ])
         ax2.set_xlabel(label[dimension_dict[dimension]])
         ax2.set_ylabel("Number of turns")
         cbar = fig2.colorbar(c, ax=ax2)
-        cbar.set_label("Number of macro-particles") 
+        cbar.set_label("Number of macro-particles")
 
     file.close()
     if profile_plot is True and streak_plot is True:
@@ -528,10 +605,18 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
         return fig
     elif streak_plot is True:
         return fig2
-    
-def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
-                     stop=None, step=None, profile_plot=False, streak_plot=True,
-                     bunch_profile=False, dipole=False):
+
+
+def plot_wakedata(filename,
+                  bunch_number,
+                  wake_type="Wlong",
+                  start=0,
+                  stop=None,
+                  step=None,
+                  profile_plot=False,
+                  streak_plot=True,
+                  bunch_profile=False,
+                  dipole=False):
     """
     Plot data recorded by WakePotentialMonitor
 
@@ -567,91 +652,105 @@ def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
         Figure object with the plot on it.
 
     """
-    
+
     file = hp.File(filename, "r")
     path = file['WakePotentialData_{0}'.format(bunch_number)]
     time = np.array(path["time"])
-    
+
     if stop is None:
         stop = time[-1]
     elif stop not in time:
-        raise ValueError("stop not found. Choose from {0}"
-                         .format(time[:]))
- 
+        raise ValueError("stop not found. Choose from {0}".format(time[:]))
+
     if start not in time:
-        raise ValueError("start not found. Choose from {0}"
-                         .format(time[:]))
-    
-    save_every = time[1] -time[0]
-    
+        raise ValueError("start not found. Choose from {0}".format(time[:]))
+
+    save_every = time[1] - time[0]
+
     if step is None:
         step = save_every
-    
+
     if step % save_every != 0:
         raise ValueError("step must be divisible by the recording step "
                          "which is {0}.".format(save_every))
-    
-    dimension_dict = {"Wlong":0, "Wxdip":1, "Wydip":2, "Wxquad":3, "Wyquad":4}
+
+    dimension_dict = {
+        "Wlong": 0,
+        "Wxdip": 1,
+        "Wydip": 2,
+        "Wxquad": 3,
+        "Wyquad": 4
+    }
     scale = [1e-12, 1e-12, 1e-12, 1e-15, 1e-15]
-    label = ["$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$", "$W_{p,x}^Q (V/pC/mm)$",
-             "$W_{p,y}^Q (V/pC/mm)$"]
-    
+    label = [
+        "$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$",
+        "$W_{p,x}^Q (V/pC/mm)$", "$W_{p,y}^Q (V/pC/mm)$"
+    ]
+
     if bunch_profile == True:
         tau_name = "tau_" + wake_type
         wake_type = "profile_" + wake_type
-        dimension_dict = {wake_type:0}
+        dimension_dict = {wake_type: 0}
         scale = [1]
         label = ["$\\rho$ (a.u.)"]
-        cmap = mpl.cm.inferno # sequential
+        cmap = mpl.cm.inferno  # sequential
     elif dipole == True:
         tau_name = "tau_" + wake_type
         wake_type = "dipole_" + wake_type
-        dimension_dict = {wake_type:0}
+        dimension_dict = {wake_type: 0}
         scale = [1]
         label = ["Dipole moment (m)"]
-        cmap = mpl.cm.coolwarm # diverging
+        cmap = mpl.cm.coolwarm  # diverging
     else:
         tau_name = "tau_" + wake_type
-        cmap = mpl.cm.coolwarm # diverging
-        
+        cmap = mpl.cm.coolwarm  # diverging
+
     data = np.array(path[wake_type])
-        
-    num = int((stop - start)/step)
-    n_bin = len(data[:,0])
-    
+
+    num = int((stop-start) / step)
+    n_bin = len(data[:, 0])
+
     start_index = np.where(time[:] == start)[0][0]
-    
-    x_var = np.zeros((num+1,n_bin))
-    turn_index_array = np.zeros((num+1,), dtype=int)
-    for i in range(num+1):
-        turn_index = int(start_index + i * step / save_every)
+
+    x_var = np.zeros((num + 1, n_bin))
+    turn_index_array = np.zeros((num + 1, ), dtype=int)
+    for i in range(num + 1):
+        turn_index = int(start_index + i*step/save_every)
         turn_index_array[i] = turn_index
         # construct an array of bin mids
-        x_var[i,:] = np.array(path[tau_name])[:,turn_index]
-                
+        x_var[i, :] = np.array(path[tau_name])[:, turn_index]
+
     if profile_plot is True:
         fig, ax = plt.subplots()
-        for i in range(num+1):
-            ax.plot(x_var[i]*1e12,
-                    data[:,turn_index_array[i]]*scale[dimension_dict[wake_type]], 
+        for i in range(num + 1):
+            ax.plot(x_var[i] * 1e12,
+                    data[:, turn_index_array[i]] *
+                    scale[dimension_dict[wake_type]],
                     label="turn {0}".format(time[turn_index_array[i]]))
         ax.set_xlabel("$\\tau$ (ps)")
-        ax.set_ylabel(label[dimension_dict[wake_type]])         
+        ax.set_ylabel(label[dimension_dict[wake_type]])
         ax.legend()
-            
+
     if streak_plot is True:
-        turn = np.reshape(time[turn_index_array], (num+1,1))
-        y_var = np.ones((num+1,n_bin)) * turn
-        z_var = np.transpose(data[:,turn_index_array]*scale[dimension_dict[wake_type]])
+        turn = np.reshape(time[turn_index_array], (num + 1, 1))
+        y_var = np.ones((num + 1, n_bin)) * turn
+        z_var = np.transpose(data[:, turn_index_array] *
+                             scale[dimension_dict[wake_type]])
         fig2, ax2 = plt.subplots()
-        c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
-                       extent=[x_var.min()*1e12,
-                               x_var.max()*1e12,
-                               y_var.min(),y_var.max()])
+        c = ax2.imshow(z_var,
+                       cmap=cmap,
+                       origin='lower',
+                       aspect='auto',
+                       extent=[
+                           x_var.min() * 1e12,
+                           x_var.max() * 1e12,
+                           y_var.min(),
+                           y_var.max()
+                       ])
         ax2.set_xlabel("$\\tau$ (ps)")
         ax2.set_ylabel("Number of turns")
         cbar = fig2.colorbar(c, ax=ax2)
-        cbar.set_label(label[dimension_dict[wake_type]]) 
+        cbar.set_label(label[dimension_dict[wake_type]])
 
     file.close()
     if profile_plot is True and streak_plot is True:
@@ -660,9 +759,16 @@ def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
         return fig
     elif streak_plot is True:
         return fig2
-    
-def plot_bunchspectrum(filenames, bunch_number, dataset="incoherent", dim="tau",
-                       turns=None, fs=None, log_scale=True, legend=None,
+
+
+def plot_bunchspectrum(filenames,
+                       bunch_number,
+                       dataset="incoherent",
+                       dim="tau",
+                       turns=None,
+                       fs=None,
+                       log_scale=True,
+                       legend=None,
                        norm=False):
     """
     Plot coherent and incoherent spectrum data.
@@ -702,34 +808,34 @@ def plot_bunchspectrum(filenames, bunch_number, dataset="incoherent", dim="tau",
     fig : Figure
 
     """
-    
+
     if isinstance(filenames, str):
         filenames = [filenames]
-        
+
     if isinstance(bunch_number, int):
         ll = []
         for i in range(len(filenames)):
             ll.append(bunch_number)
         bunch_number = ll
-        
+
     fig, ax = plt.subplots()
-    
+
     for i, filename in enumerate(filenames):
         file = hp.File(filename, "r")
         group = file["BunchSpectrum_{0}".format(bunch_number[i])]
-        
+
         time = np.array(group["time"])
         freq = np.array(group["freq"])
-        dim_dict = {"x":0, "y":1, "tau":2}
-        
+        dim_dict = {"x": 0, "y": 1, "tau": 2}
+
         if dataset == "mean_incoherent":
-            y_var = group["mean_incoherent"][dim_dict[dim],:]
-            y_err = group["std_incoherent"][dim_dict[dim],:]
+            y_var = group["mean_incoherent"][dim_dict[dim], :]
+            y_err = group["std_incoherent"][dim_dict[dim], :]
             ax.errorbar(time, y_var, y_err)
             xlabel = "Turn number"
             ylabel = "Mean incoherent frequency [Hz]"
         elif dataset == "incoherent" or dataset == "coherent":
-            
+
             if turns is None:
                 turn_index = np.where(time == time)[0]
             else:
@@ -738,40 +844,49 @@ def plot_bunchspectrum(filenames, bunch_number, dataset="incoherent", dim="tau",
                     idx = np.where(time == turn)[0][0]
                     turn_index.append(idx)
                 turn_index = np.array(turn_index)
-                
+
             if fs is None:
                 x_var = freq
                 xlabel = "Frequency [Hz]"
             else:
-                x_var = freq/fs
+                x_var = freq / fs
                 xlabel = r"$f/f_{s}$"
-                
+
             for idx in turn_index:
-                y_var = group[dataset][dim_dict[dim],:,idx]
+                y_var = group[dataset][dim_dict[dim], :, idx]
                 if norm is True:
-                    y_var = y_var/gmean(y_var)
+                    y_var = y_var / gmean(y_var)
                 ax.plot(x_var, y_var)
-                
+
             if log_scale is True:
                 ax.set_yscale('log')
-                
+
             ylabel = "FFT amplitude [a.u.]"
             if dataset == "incoherent":
                 ax.set_title("Incoherent spectrum")
             elif dataset == "coherent":
-                ax.set_title("Coherent spectrum")            
-        
+                ax.set_title("Coherent spectrum")
+
         ax.set_xlabel(xlabel)
         ax.set_ylabel(ylabel)
         if legend is not None:
             plt.legend(legend)
         file.close()
-        
+
     return fig
 
-def streak_bunchspectrum(filename, bunch_number, dataset="incoherent", 
-                         dim="tau", fs=None, log_scale=True, fmin=None, 
-                         fmax=None, turns=None, norm=False, ylim=None):
+
+def streak_bunchspectrum(filename,
+                         bunch_number,
+                         dataset="incoherent",
+                         dim="tau",
+                         fs=None,
+                         log_scale=True,
+                         fmin=None,
+                         fmax=None,
+                         turns=None,
+                         norm=False,
+                         ylim=None):
     """
     Plot 3D data recorded by the BunchSpectrumMonitor.
 
@@ -812,14 +927,14 @@ def streak_bunchspectrum(filename, bunch_number, dataset="incoherent",
     fig : Figure
 
     """
-    
+
     file = hp.File(filename, "r")
     group = file["BunchSpectrum_{0}".format(bunch_number)]
-    
+
     time = np.array(group["time"])
     freq = np.array(group["freq"])
-    dim_dict = {"x":0, "y":1, "tau":2}
-    
+    dim_dict = {"x": 0, "y": 1, "tau": 2}
+
     if turns is None:
         turn_index = np.where(time == time)[0]
         if ylim is None:
@@ -836,57 +951,67 @@ def streak_bunchspectrum(filename, bunch_number, dataset="incoherent",
             idx = np.where(time == turn)[0][0]
             turn_index.append(idx)
         turn_index = np.array(turn_index)
-    
+
     data = group[dataset][dim_dict[dim], :, turn_index]
-    
+
     if log_scale is True:
         option = mpl.colors.LogNorm()
     else:
         option = None
-    
+
     if fs is None:
         x_var = freq
         xlabel = "Frequency [Hz]"
     else:
-        x_var = freq/fs
+        x_var = freq / fs
         xlabel = r"$f/f_{s}$"
-        
+
     if fmin is None:
         fmin = x_var.min()
     if fmax is None:
         fmax = x_var.max()
-        
+
     ind = (x_var > fmin) & (x_var < fmax)
-    x_var=x_var[ind]
-    data = data[ind,:]
-    
+    x_var = x_var[ind]
+    data = data[ind, :]
+
     if norm is True:
-        data = data/gmean(data)
-    
+        data = data / gmean(data)
+
     if ylim is None:
         ylabel = "Turn number"
     else:
         ylabel = ""
-    
+
     fig, ax = plt.subplots()
     if dataset == "incoherent":
         ax.set_title("Incoherent spectrum")
     elif dataset == "coherent":
-        ax.set_title("Coherent spectrum")   
-        
-    cmap = mpl.cm.inferno # sequential
-    c = ax.imshow(data.T, cmap=cmap, origin='lower' , aspect='auto',
+        ax.set_title("Coherent spectrum")
+
+    cmap = mpl.cm.inferno  # sequential
+    c = ax.imshow(data.T,
+                  cmap=cmap,
+                  origin='lower',
+                  aspect='auto',
                   extent=[x_var.min(), x_var.max(), tmin, tmax],
-                  norm=option, interpolation="none")
+                  norm=option,
+                  interpolation="none")
     cbar = fig.colorbar(c, ax=ax)
     cbar.ax.set_ylabel("FFT amplitude [a.u.]", rotation=270)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
-    
+
     return fig
 
-def plot_beamspectrum(filenames, dim="tau", turns=None, f0=None, 
-                      log_scale=True, legend=None, norm=False):
+
+def plot_beamspectrum(filenames,
+                      dim="tau",
+                      turns=None,
+                      f0=None,
+                      log_scale=True,
+                      legend=None,
+                      norm=False):
     """
     Plot coherent beam spectrum data.
 
@@ -919,21 +1044,21 @@ def plot_beamspectrum(filenames, dim="tau", turns=None, f0=None,
     fig : Figure
 
     """
-    
+
     if isinstance(filenames, str):
         filenames = [filenames]
-        
+
     fig, ax = plt.subplots()
-    
+
     for i, filename in enumerate(filenames):
         file = hp.File(filename, "r")
         group = file["BeamSpectrum"]
-        
+
         dataset = "coherent"
         time = np.array(group["time"])
         freq = np.array(group["freq"])
-        dim_dict = {"x":0, "y":1, "tau":2}
-            
+        dim_dict = {"x": 0, "y": 1, "tau": 2}
+
         if turns is None:
             turn_index = np.where(time == time)[0]
         else:
@@ -942,36 +1067,44 @@ def plot_beamspectrum(filenames, dim="tau", turns=None, f0=None,
                 idx = np.where(time == turn)[0][0]
                 turn_index.append(idx)
             turn_index = np.array(turn_index)
-            
+
         if f0 is None:
             x_var = freq
             xlabel = "Frequency [Hz]"
         else:
-            x_var = freq/f0
+            x_var = freq / f0
             xlabel = r"$f/f_{0}$"
-            
+
         for idx in turn_index:
-            y_var = group[dataset][dim_dict[dim],:,idx]
+            y_var = group[dataset][dim_dict[dim], :, idx]
             if norm is True:
-                y_var = y_var/gmean(y_var)
+                y_var = y_var / gmean(y_var)
             ax.plot(x_var, y_var)
-            
+
         if log_scale is True:
             ax.set_yscale('log')
-            
+
         ylabel = "FFT amplitude [a.u.]"
-        ax.set_title("Beam coherent spectrum")            
-        
+        ax.set_title("Beam coherent spectrum")
+
         ax.set_xlabel(xlabel)
         ax.set_ylabel(ylabel)
         if legend is not None:
             plt.legend(legend)
         file.close()
-        
+
     return fig
 
-def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None, 
-                         fmax=None, turns=None, norm=False, ylim=None):
+
+def streak_beamspectrum(filename,
+                        dim="tau",
+                        f0=None,
+                        log_scale=True,
+                        fmin=None,
+                        fmax=None,
+                        turns=None,
+                        norm=False,
+                        ylim=None):
     """
     Plot 3D data recorded by the BeamSpectrumMonitor.
 
@@ -1005,14 +1138,14 @@ def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None,
     fig : Figure
 
     """
-    
+
     file = hp.File(filename, "r")
     group = file["BeamSpectrum"]
-    dataset="coherent"
+    dataset = "coherent"
     time = np.array(group["time"])
     freq = np.array(group["freq"])
-    dim_dict = {"x":0, "y":1, "tau":2}
-    
+    dim_dict = {"x": 0, "y": 1, "tau": 2}
+
     if turns is None:
         turn_index = np.where(time == time)[0]
         if ylim is None:
@@ -1029,54 +1162,64 @@ def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None,
             idx = np.where(time == turn)[0][0]
             turn_index.append(idx)
         turn_index = np.array(turn_index)
-    
+
     data = group[dataset][dim_dict[dim], :, turn_index]
-    
+
     if log_scale is True:
         option = mpl.colors.LogNorm()
     else:
         option = None
-    
+
     if f0 is None:
         x_var = freq
         xlabel = "Frequency [Hz]"
     else:
-        x_var = freq/f0
+        x_var = freq / f0
         xlabel = r"$f/f_{0}$"
-        
+
     if fmin is None:
         fmin = x_var.min()
     if fmax is None:
         fmax = x_var.max()
-        
+
     ind = (x_var > fmin) & (x_var < fmax)
-    x_var=x_var[ind]
-    data = data[ind,:]
-    
+    x_var = x_var[ind]
+    data = data[ind, :]
+
     if norm is True:
-        data = data/gmean(data)
-        
+        data = data / gmean(data)
+
     if ylim is None:
         ylabel = "Turn number"
     else:
         ylabel = ""
-    
+
     fig, ax = plt.subplots()
-    ax.set_title("Beam coherent spectrum")   
-        
-    cmap = mpl.cm.inferno # sequential
-    c = ax.imshow(data.T, cmap=cmap, origin='lower' , aspect='auto',
+    ax.set_title("Beam coherent spectrum")
+
+    cmap = mpl.cm.inferno  # sequential
+    c = ax.imshow(data.T,
+                  cmap=cmap,
+                  origin='lower',
+                  aspect='auto',
                   extent=[x_var.min(), x_var.max(), tmin, tmax],
-                  norm=option, interpolation="none")
+                  norm=option,
+                  interpolation="none")
     cbar = fig.colorbar(c, ax=ax)
     cbar.ax.set_ylabel("FFT amplitude [a.u.]", rotation=270)
     ax.set_xlabel(xlabel)
     ax.set_ylabel(ylabel)
-    
+
     return fig
 
-def plot_cavitydata(filename, cavity_name, phasor="cavity", 
-                    plot_type="bunch", bunch_number=0, turn=None, cm_lim=None,
+
+def plot_cavitydata(filename,
+                    cavity_name,
+                    phasor="cavity",
+                    plot_type="bunch",
+                    bunch_number=0,
+                    turn=None,
+                    cm_lim=None,
                     show_objective=False):
     """
     Plot data recorded by CavityMonitor.
@@ -1128,157 +1271,186 @@ def plot_cavitydata(filename, cavity_name, phasor="cavity",
     """
     file = hp.File(filename, "r")
     cavity_data = file[cavity_name]
-    
+
     time = np.array(cavity_data["time"])
-    
-    ph = {"cavity":0, "beam":1, "generator":2, "ig":3}
+
+    ph = {"cavity": 0, "beam": 1, "generator": 2, "ig": 3}
     labels = ["Cavity", "Beam", "Generator", "Generator"]
     units = [" voltage [MV]", " voltage [MV]", " voltage [MV]", " current [A]"]
     units_val = [1e-6, 1e-6, 1e-6, 1]
-    
+
     if (plot_type == "bunch") or (plot_type == "mean"):
-    
+
         if plot_type == "bunch":
-            data = [cavity_data["cavity_phasor_record"][bunch_number,:], 
-                    cavity_data["beam_phasor_record"][bunch_number,:],
-                    cavity_data["generator_phasor_record"][bunch_number,:],
-                    cavity_data["ig_phasor_record"][bunch_number,:]]
+            data = [
+                cavity_data["cavity_phasor_record"][bunch_number, :],
+                cavity_data["beam_phasor_record"][bunch_number, :],
+                cavity_data["generator_phasor_record"][bunch_number, :],
+                cavity_data["ig_phasor_record"][bunch_number, :]
+            ]
         elif plot_type == "mean":
             try:
-                bunch_index = (file["Beam"]["current"][:,0] != 0).nonzero()[0]
+                bunch_index = (file["Beam"]["current"][:, 0] != 0).nonzero()[0]
             except:
                 ValueError("Beam monitor is needed to show mean voltage.")
-            data = [np.mean(cavity_data["cavity_phasor_record"][bunch_index,:],0), 
-                    np.mean(cavity_data["beam_phasor_record"][bunch_index,:],0),
-                    np.mean(cavity_data["generator_phasor_record"][bunch_index,:],0),
-                    np.mean(cavity_data["ig_phasor_record"][bunch_index,:],0)]
-            
+            data = [
+                np.mean(cavity_data["cavity_phasor_record"][bunch_index, :],
+                        0),
+                np.mean(cavity_data["beam_phasor_record"][bunch_index, :], 0),
+                np.mean(cavity_data["generator_phasor_record"][bunch_index, :],
+                        0),
+                np.mean(cavity_data["ig_phasor_record"][bunch_index, :], 0)
+            ]
+
         ylabel1 = labels[ph[phasor]] + units[ph[phasor]]
         ylabel2 = labels[ph[phasor]] + " phase [rad]"
-        
+
         fig, ax = plt.subplots()
         twin = ax.twinx()
-        p1, = ax.plot(time, np.abs(data[ph[phasor]])*units_val[ph[phasor]], color="r",label=ylabel1)
-        p2, = twin.plot(time, np.angle(data[ph[phasor]]), color="b", label=ylabel2)
+        p1, = ax.plot(time,
+                      np.abs(data[ph[phasor]]) * units_val[ph[phasor]],
+                      color="r",
+                      label=ylabel1)
+        p2, = twin.plot(time,
+                        np.angle(data[ph[phasor]]),
+                        color="b",
+                        label=ylabel2)
         if show_objective:
-            o1, = ax.plot(time, np.array(cavity_data["Vc"])*1e-6, "r--", label="Cavity voltage objective value [MV]")
-            o2, = twin.plot(time, np.array(cavity_data["theta"]), "b--", label="Cavity phase objective value [rad]")
+            o1, = ax.plot(time,
+                          np.array(cavity_data["Vc"]) * 1e-6,
+                          "r--",
+                          label="Cavity voltage objective value [MV]")
+            o2, = twin.plot(time,
+                            np.array(cavity_data["theta"]),
+                            "b--",
+                            label="Cavity phase objective value [rad]")
         ax.set_xlabel("Turn number")
         ax.set_ylabel(ylabel1)
         twin.set_ylabel(ylabel2)
-        
+
         if show_objective:
             plots = [p1, o1, p2, o2]
         else:
             plots = [p1, p2]
         ax.legend(handles=plots, loc="best")
-        
+
         ax.yaxis.label.set_color("r")
         twin.yaxis.label.set_color("b")
-        
+
     if plot_type == "turn":
-        
+
         index = np.array(time) == turn
         if (index.size == 0):
             raise ValueError("Turn is not valid.")
-        data = [cavity_data["cavity_phasor_record"][:,index], 
-                cavity_data["beam_phasor_record"][:,index],
-                cavity_data["generator_phasor_record"][:,index],
-                cavity_data["ig_phasor_record"][:,index]]
-        
-        h=len(data[0])
-        x=np.arange(h)
+        data = [
+            cavity_data["cavity_phasor_record"][:, index],
+            cavity_data["beam_phasor_record"][:, index],
+            cavity_data["generator_phasor_record"][:, index],
+            cavity_data["ig_phasor_record"][:, index]
+        ]
+
+        h = len(data[0])
+        x = np.arange(h)
 
         ylabel1 = labels[ph[phasor]] + units[ph[phasor]]
         ylabel2 = labels[ph[phasor]] + " phase [rad]"
-        
+
         fig, ax = plt.subplots()
         twin = ax.twinx()
-        p1, = ax.plot(x, np.abs(data[ph[phasor]])*units_val[ph[phasor]], color="r",label=ylabel1)
-        p2, = twin.plot(x, np.angle(data[ph[phasor]]), color="b", label=ylabel2)
+        p1, = ax.plot(x,
+                      np.abs(data[ph[phasor]]) * units_val[ph[phasor]],
+                      color="r",
+                      label=ylabel1)
+        p2, = twin.plot(x,
+                        np.angle(data[ph[phasor]]),
+                        color="b",
+                        label=ylabel2)
         ax.set_xlabel("Bunch index")
         ax.set_ylabel(ylabel1)
         twin.set_ylabel(ylabel2)
-        
+
         plots = [p1, p2]
         ax.legend(handles=plots, loc="best")
-        
+
         ax.yaxis.label.set_color("r")
         twin.yaxis.label.set_color("b")
-        
+
     if plot_type == "streak_amplitude" or plot_type == "streak_phase":
-        
-        data = [cavity_data["cavity_phasor_record"][:,:], 
-                cavity_data["beam_phasor_record"][:,:],
-                cavity_data["generator_phasor_record"][:,:],
-                cavity_data["ig_phasor_record"][:,:]]
-        
+
+        data = [
+            cavity_data["cavity_phasor_record"][:, :],
+            cavity_data["beam_phasor_record"][:, :],
+            cavity_data["generator_phasor_record"][:, :],
+            cavity_data["ig_phasor_record"][:, :]
+        ]
+
         if plot_type == "streak_amplitude":
-            data = np.transpose(np.abs(data[ph[phasor]])*units_val[ph[phasor]])
+            data = np.transpose(
+                np.abs(data[ph[phasor]]) * units_val[ph[phasor]])
             ylabel = labels[ph[phasor]] + units[ph[phasor]]
-            cmap = mpl.cm.coolwarm # diverging
+            cmap = mpl.cm.coolwarm  # diverging
         elif plot_type == "streak_phase":
             data = np.transpose(np.angle(data[ph[phasor]]))
             ylabel = labels[ph[phasor]] + " phase [rad]"
-            cmap = mpl.cm.coolwarm # diverging
-            
+            cmap = mpl.cm.coolwarm  # diverging
+
         fig, ax = plt.subplots()
-        c = ax.imshow(data, cmap=cmap, origin='lower' , aspect='auto')
+        c = ax.imshow(data, cmap=cmap, origin='lower', aspect='auto')
         if cm_lim is not None:
-            c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
+            c.set_clim(vmin=cm_lim[0], vmax=cm_lim[1])
         ax.set_xlabel("Bunch index")
         ax.set_ylabel("Number of turns")
         cbar = fig.colorbar(c, ax=ax)
         cbar.set_label(ylabel)
-        
+
     if plot_type == "detune" or plot_type == "psi":
-        
+
         fig, ax = plt.subplots()
         if plot_type == "detune":
-            data = np.array(cavity_data["detune"])*1e-3
+            data = np.array(cavity_data["detune"]) * 1e-3
             ylabel = r"Detuning $\Delta f$ [kHz]"
         elif plot_type == "psi":
             data = np.array(cavity_data["psi"])
             ylabel = r"Tuning angle $\psi$"
-            
+
         ax.plot(time, data)
         ax.set_xlabel("Number of turns")
         ax.set_ylabel(ylabel)
-        
+
     if plot_type == "power":
-        Vc = np.mean(np.abs(cavity_data["cavity_phasor_record"]),0)
-        theta = np.mean(np.angle(cavity_data["cavity_phasor_record"]),0)
+        Vc = np.mean(np.abs(cavity_data["cavity_phasor_record"]), 0)
+        theta = np.mean(np.angle(cavity_data["cavity_phasor_record"]), 0)
         try:
-            bunch_index = (file["Beam"]["current"][:,0] != 0).nonzero()[0]
-            I0 = np.nansum(file["Beam"]["current"][bunch_index,:],0)
+            bunch_index = (file["Beam"]["current"][:, 0] != 0).nonzero()[0]
+            I0 = np.nansum(file["Beam"]["current"][bunch_index, :], 0)
         except:
             print("Beam monitor is needed to compute power.")
-            
+
         Rs = np.array(cavity_data["Rs"])
-        Pc = Vc**2 / (2 * Rs)
+        Pc = Vc**2 / (2*Rs)
         Pb = I0 * Vc * np.cos(theta)
         Pg = np.array(cavity_data["Pg"])
         Pr = Pg - Pb - Pc
-        
+
         fig, ax = plt.subplots()
-        ax.plot(time, Pg*1e-3, label="Generator power $P_g$ [kW]")
-        ax.plot(time, Pb*1e-3, label="Beam power $P_b$ [kW]")
-        ax.plot(time, Pc*1e-3, label="Dissipated cavity power $P_c$ [kW]")
-        ax.plot(time, Pr*1e-3, label="Reflected power $P_r$ [kW]")
+        ax.plot(time, Pg * 1e-3, label="Generator power $P_g$ [kW]")
+        ax.plot(time, Pb * 1e-3, label="Beam power $P_b$ [kW]")
+        ax.plot(time, Pc * 1e-3, label="Dissipated cavity power $P_c$ [kW]")
+        ax.plot(time, Pr * 1e-3, label="Reflected power $P_r$ [kW]")
         ax.set_xlabel("Number of turns")
         ax.set_ylabel("Power [kW]")
         plt.legend()
-        
+
     if plot_type == "tuner_diff":
-        data0 = np.angle(cavity_data["cavity_phasor_record"][bunch_number,:])
+        data0 = np.angle(cavity_data["cavity_phasor_record"][bunch_number, :])
         data1 = np.array(cavity_data["psi"])
         data2 = np.array(cavity_data["theta_g"])
-        
+
         ylabel1 = "Tuner diff. from optimal [rad]"
         ylabel2 = r"Tuning angle $\psi$ [rad]"
         fig, ax = plt.subplots()
         twin = ax.twinx()
-        p1, = ax.plot(time, data0-data2+data1, color="r",label=ylabel1)
+        p1, = ax.plot(time, data0 - data2 + data1, color="r", label=ylabel1)
         p2, = twin.plot(time, data1, color="b", label=ylabel2)
         ax.set_xlabel("Turn number")
         ax.set_ylabel(ylabel1)
@@ -1287,6 +1459,6 @@ def plot_cavitydata(filename, cavity_name, phasor="cavity",
         ax.legend(handles=plots, loc="best")
         ax.yaxis.label.set_color("r")
         twin.yaxis.label.set_color("b")
-        
+
     file.close()
     return fig
diff --git a/mbtrack2/tracking/monitors/tools.py b/mbtrack2/tracking/monitors/tools.py
index 771ff16373f6d48b75f699874d5083e44cfb208c..743c15e7dfccca47e95bb8d3a700dbd6b52baac2 100644
--- a/mbtrack2/tracking/monitors/tools.py
+++ b/mbtrack2/tracking/monitors/tools.py
@@ -4,8 +4,9 @@ This module defines utilities functions, helping to deals with tracking output
 and hdf5 files.
 """
 
-import numpy as np
 import h5py as hp
+import numpy as np
+
 
 def merge_files(files_prefix, files_number, start_idx=0, file_name=None):
     """
@@ -34,7 +35,7 @@ def merge_files(files_prefix, files_number, start_idx=0, file_name=None):
     if file_name == None:
         file_name = files_prefix
     f = hp.File(file_name + ".hdf5", "a")
-    
+
     ## Create file architecture
     f0 = hp.File(files_prefix + "_" + str(start_idx) + ".hdf5", "r")
     for group in list(f0):
@@ -46,12 +47,12 @@ def merge_files(files_prefix, files_number, start_idx=0, file_name=None):
             shape = f0[group][dataset_name].shape
             dtype = f0[group][dataset_name].dtype
             shape_needed = list(shape)
-            shape_needed[-1] = shape_needed[-1]*files_number
+            shape_needed[-1] = shape_needed[-1] * files_number
             shape_needed = tuple(shape_needed)
             f[group].create_dataset(dataset_name, shape_needed, dtype)
-            
+
     f0.close()
-    
+
     ## Copy data
     for i, file_num in enumerate(range(start_idx, start_idx + files_number)):
         fi = hp.File(files_prefix + "_" + str(file_num) + ".hdf5", "r")
@@ -63,16 +64,19 @@ def merge_files(files_prefix, files_number, start_idx=0, file_name=None):
                 slice_list = []
                 for n in range(n_slice):
                     slice_list.append(slice(None))
-                slice_list.append(slice(length*i,length*(i+1)))
+                slice_list.append(slice(length * i, length * (i+1)))
                 if (dataset_name == "freq"):
                     continue
                 if (dataset_name == "time") and (file_num != start_idx):
-                    f[group][dataset_name][tuple(slice_list)] = np.max(f[group][dataset_name][:]) + fi[group][dataset_name]
+                    f[group][dataset_name][tuple(slice_list)] = np.max(
+                        f[group][dataset_name][:]) + fi[group][dataset_name]
                 else:
-                    f[group][dataset_name][tuple(slice_list)] = fi[group][dataset_name]
+                    f[group][dataset_name][tuple(
+                        slice_list)] = fi[group][dataset_name]
         fi.close()
     f.close()
-    
+
+
 def copy_files(source, copy, version=None):
     """
     Copy a source hdf5 file into another hdf5 file using a different HDF5 
@@ -94,14 +98,12 @@ def copy_files(source, copy, version=None):
         version = 'v108'
     f = hp.File(source + ".hdf5", "r")
     h = hp.File(copy + ".hdf5", "a", libver=('earliest', version))
-    
+
     ## Copy file
     for group in list(f):
         h.require_group(group)
         for dataset_name in list(f[group]):
             h[group][dataset_name] = f[group][dataset_name][()]
-            
+
     f.close()
     h.close()
-
-            
\ No newline at end of file
diff --git a/mbtrack2/tracking/parallel.py b/mbtrack2/tracking/parallel.py
index b71cf4d66d6f5661bee12c17a177f02c8cf689cb..58c45b116a85901e7cb173c58e0ac934805cffc7 100644
--- a/mbtrack2/tracking/parallel.py
+++ b/mbtrack2/tracking/parallel.py
@@ -1,11 +1,11 @@
 # -*- coding: utf-8 -*-
-
 """
 Module to handle parallel computation
 """
 
 import numpy as np
 
+
 class Mpi:
     """
     Class which handle parallel computation via the mpi4py module [1].
@@ -65,7 +65,7 @@ class Mpi:
         self.rank = self.comm.Get_rank()
         self.size = self.comm.Get_size()
         self.write_table(filling_pattern)
-        
+
     def write_table(self, filling_pattern):
         """
         Write a table with the rank and the corresponding bunch number for each
@@ -76,14 +76,14 @@ class Mpi:
         filling_pattern : bool array of shape (h,)
             Filling pattern of the beam, like Beam.filling_pattern
         """
-        if(filling_pattern.sum() != self.size):
+        if (filling_pattern.sum() != self.size):
             raise ValueError("The number of processors must be equal to the"
                              "number of (non-empty) bunches.")
-        table = np.zeros((self.size, 2), dtype = int)
-        table[:,0] = np.arange(0, self.size)
-        table[:,1] = np.where(filling_pattern)[0]
+        table = np.zeros((self.size, 2), dtype=int)
+        table[:, 0] = np.arange(0, self.size)
+        table[:, 1] = np.where(filling_pattern)[0]
         self.table = table
-    
+
     def rank_to_bunch(self, rank):
         """
         Return the bunch number corresponding to rank
@@ -98,8 +98,8 @@ class Mpi:
         bunch_num : int
             Bunch number corresponding to the input rank
         """
-        return self.table[rank,1]
-    
+        return self.table[rank, 1]
+
     def bunch_to_rank(self, bunch_num):
         """
         Return the rank corresponding to the bunch number bunch_num
@@ -115,33 +115,34 @@ class Mpi:
             Rank of the processor which tracks the input bunch number
         """
         try:
-            rank = np.where(self.table[:,1] == bunch_num)[0][0]
+            rank = np.where(self.table[:, 1] == bunch_num)[0][0]
         except IndexError:
-            print("The bunch " + str(bunch_num) + " is not tracked on any processor.")
+            print("The bunch " + str(bunch_num) +
+                  " is not tracked on any processor.")
             rank = None
         return rank
-    
+
     @property
     def bunch_num(self):
         """Return the bunch number corresponding to the current processor"""
         return self.rank_to_bunch(self.rank)
-    
+
     @property
     def next_bunch(self):
         """Return the rank of the next tracked bunch"""
-        if self.rank + 1 in self.table[:,0]:
+        if self.rank + 1 in self.table[:, 0]:
             return self.rank + 1
         else:
             return 0
-    
+
     @property
     def previous_bunch(self):
         """Return the rank of the previous tracked bunch"""
-        if self.rank - 1 in self.table[:,0]:
+        if self.rank - 1 in self.table[:, 0]:
             return self.rank - 1
         else:
-            return max(self.table[:,0])
-        
+            return max(self.table[:, 0])
+
     def share_distributions(self, beam, dimensions="tau", n_bin=75):
         """
         Compute the bunch profiles and share it between the different bunches.
@@ -155,44 +156,51 @@ class Mpi:
             Number of bins. The default is 75.
 
         """
-        
-        if(beam.mpi_switch == False):
+
+        if (beam.mpi_switch == False):
             print("Error, mpi is not initialised.")
-            
+
         if isinstance(dimensions, str):
             dimensions = [dimensions]
-            
+
         if isinstance(n_bin, int):
-            n_bin = np.ones((len(dimensions),), dtype=int)*n_bin
-            
+            n_bin = np.ones((len(dimensions), ), dtype=int) * n_bin
+
         bunch = beam[self.bunch_num]
-        
+
         charge_per_mp_all = self.comm.allgather(bunch.charge_per_mp)
         self.charge_per_mp_all = charge_per_mp_all
-            
+
         for i in range(len(dimensions)):
-            
+
             dim = dimensions[i]
             n = n_bin[i]
-            
+
             if len(bunch) != 0:
-                bins, sorted_index, profile, center = bunch.binning(dimension=dim, n_bin=n)
+                bins, sorted_index, profile, center = bunch.binning(
+                    dimension=dim, n_bin=n)
             else:
                 sorted_index = None
-                profile = np.zeros((n-1,),dtype=np.int64)
-                center = np.zeros((n-1,),dtype=np.float64)
+                profile = np.zeros((n - 1, ), dtype=np.int64)
+                center = np.zeros((n - 1, ), dtype=np.float64)
                 if beam.filling_pattern[self.bunch_num] is True:
                     beam.update_filling_pattern()
                     beam.update_distance_between_bunches()
-               
-            self.__setattr__(dim + "_center", np.empty((self.size, n-1), dtype=np.float64))
-            self.comm.Allgather([center,  self.MPI.DOUBLE], [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE])
-            
-            self.__setattr__(dim + "_profile", np.empty((self.size, n-1), dtype=np.int64))
-            self.comm.Allgather([profile,  self.MPI.INT64_T], [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T])
-            
+
+            self.__setattr__(dim + "_center",
+                             np.empty((self.size, n - 1), dtype=np.float64))
+            self.comm.Allgather(
+                [center, self.MPI.DOUBLE],
+                [self.__getattribute__(dim + "_center"), self.MPI.DOUBLE])
+
+            self.__setattr__(dim + "_profile",
+                             np.empty((self.size, n - 1), dtype=np.int64))
+            self.comm.Allgather(
+                [profile, self.MPI.INT64_T],
+                [self.__getattribute__(dim + "_profile"), self.MPI.INT64_T])
+
             self.__setattr__(dim + "_sorted_index", sorted_index)
-            
+
     def share_means(self, beam):
         """
         Compute the bunch means and share it between the different bunches.
@@ -202,22 +210,23 @@ class Mpi:
         beam : Beam object
 
         """
-        
-        if(beam.mpi_switch == False):
+
+        if (beam.mpi_switch == False):
             print("Error, mpi is not initialised.")
-            
+
         bunch = beam[self.bunch_num]
-        
+
         charge_all = self.comm.allgather(bunch.charge)
         self.charge_all = charge_all
-        
+
         self.mean_all = np.empty((self.size, 6), dtype=np.float64)
         if len(bunch) != 0:
             mean = bunch.mean
         else:
-            mean = np.zeros((6,), dtype=np.float64)
-        self.comm.Allgather([mean, self.MPI.DOUBLE], [self.mean_all, self.MPI.DOUBLE])
-        
+            mean = np.zeros((6, ), dtype=np.float64)
+        self.comm.Allgather([mean, self.MPI.DOUBLE],
+                            [self.mean_all, self.MPI.DOUBLE])
+
     def share_stds(self, beam):
         """
         Compute the bunch standard deviations and share it between the 
@@ -228,18 +237,18 @@ class Mpi:
         beam : Beam object
 
         """
-        if(beam.mpi_switch == False):
+        if (beam.mpi_switch == False):
             print("Error, mpi is not initialised.")
-            
+
         bunch = beam[self.bunch_num]
-        
+
         charge_all = self.comm.allgather(bunch.charge)
         self.charge_all = charge_all
-        
+
         self.std_all = np.empty((self.size, 6), dtype=np.float64)
         if len(bunch) != 0:
             std = bunch.std
         else:
-            std = np.zeros((6,), dtype=np.float64)
-        self.comm.Allgather([std, self.MPI.DOUBLE], [self.std_all, self.MPI.DOUBLE])
-                                
\ No newline at end of file
+            std = np.zeros((6, ), dtype=np.float64)
+        self.comm.Allgather([std, self.MPI.DOUBLE],
+                            [self.std_all, self.MPI.DOUBLE])
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index 356fb3d564bca9b4e27f3cc4d04c107690bcb6e0..1c3f20b12b53e40837e7c99922db83fa0cc6f096 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -3,12 +3,13 @@
 Module where particles, bunches and beams are described as objects.
 """
 
-import numpy as np
+import h5py as hp
 import matplotlib.pyplot as plt
-import seaborn as sns
+import numpy as np
 import pandas as pd
-import h5py as hp
-from scipy.constants import c, m_e, m_p, e
+import seaborn as sns
+from scipy.constants import c, e, m_e, m_p
+
 
 class Particle:
     """
@@ -29,18 +30,21 @@ class Particle:
 
     @property
     def E_rest(self):
-        return self.mass * c ** 2 / e
-    
+        return self.mass * c**2 / e
+
+
 class Electron(Particle):
     """ Define an electron"""
     def __init__(self):
-        super().__init__(m_e, -1*e)
-        
+        super().__init__(m_e, -1 * e)
+
+
 class Proton(Particle):
     """ Define a proton"""
     def __init__(self):
         super().__init__(m_p, e)
 
+
 class Bunch:
     """
     Define a bunch object.
@@ -109,118 +113,120 @@ class Bunch:
     [1] Wiedemann, H. (2015). Particle accelerator physics. 4th edition. 
     Springer, Eq.(8.39) of p224.
     """
-    
-    def __init__(self, ring, mp_number=1e3, current=1e-3, track_alive=True,
-                 alive=True, load_from_file=None, load_suffix=None):
-        
+    def __init__(self,
+                 ring,
+                 mp_number=1e3,
+                 current=1e-3,
+                 track_alive=True,
+                 alive=True,
+                 load_from_file=None,
+                 load_suffix=None):
+
         self.ring = ring
         if not alive:
             mp_number = 1
             current = 0
         self._mp_number = int(mp_number)
-        
-        self.dtype = np.dtype([('x', float),
-                       ('xp', float),
-                       ('y', float),
-                       ('yp', float),
-                       ('tau', float),
-                       ('delta', float)])
-        
+
+        self.dtype = np.dtype([('x', float), ('xp', float), ('y', float),
+                               ('yp', float), ('tau', float),
+                               ('delta', float)])
+
         self.particles = np.zeros(self.mp_number, self.dtype)
         self.track_alive = track_alive
-        self.alive = np.ones((self.mp_number,),dtype=bool)
+        self.alive = np.ones((self.mp_number, ), dtype=bool)
         self.current = current
         if not alive:
-            self.alive = np.zeros((self.mp_number,),dtype=bool)
-            
+            self.alive = np.zeros((self.mp_number, ), dtype=bool)
+
         if load_from_file is not None:
             self.load(load_from_file, load_suffix, track_alive)
-        
+
     def __len__(self):
         """Return the number of alive particles"""
         return len(self[:])
-        
+
     def __getitem__(self, label):
         """Return the columns label for alive particles"""
         if self.track_alive is True:
             return self.particles[label][self.alive]
         else:
             return self.particles[label]
-    
+
     def __setitem__(self, label, value):
         """Set value to the columns label for alive particles"""
         if self.track_alive is True:
             self.particles[label][self.alive] = value
         else:
             self.particles[label] = value
-    
+
     def __iter__(self):
         """Iterate over labels"""
         return self.dtype.names.__iter__()
-    
+
     def __repr__(self):
         """Return representation of alive particles"""
         return f'Bunch with macro-particles: \n {pd.DataFrame(self[:])!r}'
-        
+
     @property
     def mp_number(self):
         """Macro-particle number"""
         return self._mp_number
-    
+
     @mp_number.setter
     def mp_number(self, value):
         self._mp_number = int(value)
         self.__init__(self.ring, value, self.charge)
-        
+
     @property
     def charge_per_mp(self):
         """Charge per macro-particle [C]"""
         return self._charge_per_mp
-    
+
     @charge_per_mp.setter
     def charge_per_mp(self, value):
         self._charge_per_mp = value
-        
+
     @property
     def charge(self):
         """Bunch charge in [C]"""
-        return self.__len__()*self.charge_per_mp
-    
+        return self.__len__() * self.charge_per_mp
+
     @charge.setter
     def charge(self, value):
         self.charge_per_mp = value / self.__len__()
-    
+
     @property
     def particle_number(self):
         """Particle number"""
         return int(self.charge / np.abs(self.ring.particle.charge))
-    
+
     @particle_number.setter
     def particle_number(self, value):
         self.charge_per_mp = value * self.ring.particle.charge / self.__len__()
-        
+
     @property
     def current(self):
         """Bunch current [A]"""
         return self.charge / self.ring.T0
-    
+
     @current.setter
     def current(self, value):
         self.charge_per_mp = value * self.ring.T0 / self.__len__()
-        
+
     @property
     def is_empty(self):
         """Return True if the bunch is empty."""
         return ~np.any(self.alive)
-    
-    @property    
+
+    @property
     def mean(self):
         """
         Return the mean position of alive particles for each coordinates.
         """
         mean = [[self[name].mean()] for name in self]
         return np.squeeze(np.array(mean))
-    
+
     @property
     def std(self):
         """
@@ -229,21 +235,24 @@ class Bunch:
         """
         std = [[self[name].std()] for name in self]
         return np.squeeze(np.array(std))
-    
-    @property    
+
+    @property
     def emit(self):
         """
         Return the bunch emittance for each plane.
         """
         cor = np.squeeze([[self[name] - self[name].mean()] for name in self])
-        emitX = np.sqrt(np.mean(cor[0]**2)*np.mean(cor[1]**2) - 
-                        np.mean(cor[0]*cor[1])**2)
-        emitY = np.sqrt(np.mean(cor[2]**2)*np.mean(cor[3]**2) - 
-                        np.mean(cor[2]*cor[3])**2)
-        emitS = np.sqrt(np.mean(cor[4]**2)*np.mean(cor[5]**2) - 
-                        np.mean(cor[4]*cor[5])**2)
+        emitX = np.sqrt(
+            np.mean(cor[0]**2) * np.mean(cor[1]**2) -
+            np.mean(cor[0] * cor[1])**2)
+        emitY = np.sqrt(
+            np.mean(cor[2]**2) * np.mean(cor[3]**2) -
+            np.mean(cor[2] * cor[3])**2)
+        emitS = np.sqrt(
+            np.mean(cor[4]**2) * np.mean(cor[5]**2) -
+            np.mean(cor[4] * cor[5])**2)
         return np.array([emitX, emitY, emitS])
-    
+
     @property
     def cs_invariant(self):
         """
@@ -260,7 +269,7 @@ class Bunch:
               (2*self.ring.long_alpha * self['tau']*self['delta']) + \
               (self.ring.long_beta * self['delta']**2)
         return np.array((np.mean(Jx), np.mean(Jy), np.mean(Js)))
-        
+
     def init_gaussian(self, cov=None, mean=None, **kwargs):
         """
         Initialize bunch particles with 6D gaussian phase space.
@@ -282,41 +291,53 @@ class Bunch:
 
         """
         if mean is None:
-            mean = np.zeros((6,))
-        
+            mean = np.zeros((6, ))
+
         if cov is None:
             sigma_0 = kwargs.get("sigma_0", self.ring.sigma_0)
             sigma_delta = kwargs.get("sigma_delta", self.ring.sigma_delta)
             optics = kwargs.get("optics", self.ring.optics)
-            
-            cov = np.zeros((6,6))
-            cov[0,0] = self.ring.emit[0]*optics.local_beta[0] + (optics.local_dispersion[0]*self.ring.sigma_delta)**2
-            cov[1,1] = self.ring.emit[0]*optics.local_gamma[0] + (optics.local_dispersion[1]*self.ring.sigma_delta)**2
-            cov[0,1] = -1*self.ring.emit[0]*optics.local_alpha[0] + (optics.local_dispersion[0]*optics.local_dispersion[1]*self.ring.sigma_delta**2)
-            cov[1,0] = -1*self.ring.emit[0]*optics.local_alpha[0] + (optics.local_dispersion[0]*optics.local_dispersion[1]*self.ring.sigma_delta**2)
-            cov[0,5] = optics.local_dispersion[0]*self.ring.sigma_delta**2
-            cov[5,0] = optics.local_dispersion[0]*self.ring.sigma_delta**2
-            cov[1,5] = optics.local_dispersion[1]*self.ring.sigma_delta**2
-            cov[5,1] = optics.local_dispersion[1]*self.ring.sigma_delta**2
-            cov[2,2] = self.ring.emit[1]*optics.local_beta[1] + (optics.local_dispersion[2]*self.ring.sigma_delta)**2
-            cov[3,3] = self.ring.emit[1]*optics.local_gamma[1] + (optics.local_dispersion[3]*self.ring.sigma_delta)**2
-            cov[2,3] = -1*self.ring.emit[1]*optics.local_alpha[1] + (optics.local_dispersion[2]*optics.local_dispersion[3]*self.ring.sigma_delta**2)
-            cov[3,2] = -1*self.ring.emit[1]*optics.local_alpha[1] + (optics.local_dispersion[2]*optics.local_dispersion[3]*self.ring.sigma_delta**2)
-            cov[2,5] = optics.local_dispersion[2]*self.ring.sigma_delta**2
-            cov[5,2] = optics.local_dispersion[2]*self.ring.sigma_delta**2
-            cov[3,5] = optics.local_dispersion[3]*self.ring.sigma_delta**2
-            cov[5,3] = optics.local_dispersion[3]*self.ring.sigma_delta**2
-            cov[4,4] = sigma_0**2
-            cov[5,5] = sigma_delta**2
-            
+
+            cov = np.zeros((6, 6))
+            cov[0, 0] = self.ring.emit[0] * optics.local_beta[0] + (
+                optics.local_dispersion[0] * self.ring.sigma_delta)**2
+            cov[1, 1] = self.ring.emit[0] * optics.local_gamma[0] + (
+                optics.local_dispersion[1] * self.ring.sigma_delta)**2
+            cov[0, 1] = -1 * self.ring.emit[0] * optics.local_alpha[0] + (
+                optics.local_dispersion[0] * optics.local_dispersion[1] *
+                self.ring.sigma_delta**2)
+            cov[1, 0] = -1 * self.ring.emit[0] * optics.local_alpha[0] + (
+                optics.local_dispersion[0] * optics.local_dispersion[1] *
+                self.ring.sigma_delta**2)
+            cov[0, 5] = optics.local_dispersion[0] * self.ring.sigma_delta**2
+            cov[5, 0] = optics.local_dispersion[0] * self.ring.sigma_delta**2
+            cov[1, 5] = optics.local_dispersion[1] * self.ring.sigma_delta**2
+            cov[5, 1] = optics.local_dispersion[1] * self.ring.sigma_delta**2
+            cov[2, 2] = self.ring.emit[1] * optics.local_beta[1] + (
+                optics.local_dispersion[2] * self.ring.sigma_delta)**2
+            cov[3, 3] = self.ring.emit[1] * optics.local_gamma[1] + (
+                optics.local_dispersion[3] * self.ring.sigma_delta)**2
+            cov[2, 3] = -1 * self.ring.emit[1] * optics.local_alpha[1] + (
+                optics.local_dispersion[2] * optics.local_dispersion[3] *
+                self.ring.sigma_delta**2)
+            cov[3, 2] = -1 * self.ring.emit[1] * optics.local_alpha[1] + (
+                optics.local_dispersion[2] * optics.local_dispersion[3] *
+                self.ring.sigma_delta**2)
+            cov[2, 5] = optics.local_dispersion[2] * self.ring.sigma_delta**2
+            cov[5, 2] = optics.local_dispersion[2] * self.ring.sigma_delta**2
+            cov[3, 5] = optics.local_dispersion[3] * self.ring.sigma_delta**2
+            cov[5, 3] = optics.local_dispersion[3] * self.ring.sigma_delta**2
+            cov[4, 4] = sigma_0**2
+            cov[5, 5] = sigma_delta**2
+
         values = np.random.multivariate_normal(mean, cov, size=self.mp_number)
-        self.particles["x"] = values[:,0]
-        self.particles["xp"] = values[:,1]
-        self.particles["y"] = values[:,2]
-        self.particles["yp"] = values[:,3]
-        self.particles["tau"] = values[:,4]
-        self.particles["delta"] = values[:,5]
-        
+        self.particles["x"] = values[:, 0]
+        self.particles["xp"] = values[:, 1]
+        self.particles["y"] = values[:, 2]
+        self.particles["yp"] = values[:, 3]
+        self.particles["tau"] = values[:, 4]
+        self.particles["delta"] = values[:, 5]
+
     def binning(self, dimension="tau", n_bin=75):
         """
         Bin macro-particles.
@@ -341,18 +362,18 @@ class Bunch:
 
         """
         bin_min = self[dimension].min()
-        bin_min = min(bin_min*0.99, bin_min*1.01)
+        bin_min = min(bin_min * 0.99, bin_min * 1.01)
         bin_max = self[dimension].max()
-        bin_max = max(bin_max*0.99, bin_max*1.01)
-        
+        bin_max = max(bin_max * 0.99, bin_max * 1.01)
+
         bins = np.linspace(bin_min, bin_max, n_bin)
-        center = (bins[1:] + bins[:-1])/2
+        center = (bins[1:] + bins[:-1]) / 2
         sorted_index = np.searchsorted(bins, self[dimension], side='left')
         sorted_index -= 1
-        profile = np.bincount(sorted_index, minlength=n_bin-1)
-        
+        profile = np.bincount(sorted_index, minlength=n_bin - 1)
+
         return (bins, sorted_index, profile, center)
-    
+
     def plot_profile(self, dimension="tau", n_bin=75):
         """
         Plot bunch profile.
@@ -369,7 +390,7 @@ class Bunch:
         fig = plt.figure()
         ax = fig.gca()
         ax.plot(center, profile)
-        
+
     def plot_phasespace(self, x_var="tau", y_var="delta", kind="scatter"):
         """
         Plot phase space.
@@ -389,19 +410,32 @@ class Bunch:
         fig : Figure
             Figure object with the plot on it.
         """
-        
-        label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", 
-                      "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"}
-        scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1}
 
-        fig = sns.jointplot(x=self.particles[x_var]*scale[x_var],
-                            y=self.particles[y_var]*scale[y_var],
+        label_dict = {
+            "x": "x (mm)",
+            "xp": "x' (mrad)",
+            "y": "y (mm)",
+            "yp": "y' (mrad)",
+            "tau": "$\\tau$ (ps)",
+            "delta": "$\\delta$"
+        }
+        scale = {
+            "x": 1e3,
+            "xp": 1e3,
+            "y": 1e3,
+            "yp": 1e3,
+            "tau": 1e12,
+            "delta": 1
+        }
+
+        fig = sns.jointplot(x=self.particles[x_var] * scale[x_var],
+                            y=self.particles[y_var] * scale[y_var],
                             kind=kind)
         plt.xlabel(label_dict[x_var])
         plt.ylabel(label_dict[y_var])
-            
+
         return fig
-    
+
     def save(self, file_name, suffix=None, mpi_comm=None):
         """
         Save bunch object data (6D phase space, current, and state) in an HDF5 
@@ -420,30 +454,33 @@ class Bunch:
             For internal use if mpi is used in Beam objects.
             Default is None.
         """
-        
+
         if mpi_comm is None:
             f = hp.File(file_name + ".hdf5", "a", libver="earliest")
         else:
-            f = hp.File(file_name+ ".hdf5", "a", libver='earliest', 
-                        driver='mpio', comm=mpi_comm)
-        
+            f = hp.File(file_name + ".hdf5",
+                        "a",
+                        libver='earliest',
+                        driver='mpio',
+                        comm=mpi_comm)
+
         if suffix is None:
             group_name = "Bunch"
         else:
             group_name = "Bunch_" + str(suffix)
-            
+
         g = f.create_group(group_name)
-        g.create_dataset("alive", (self.mp_number,), dtype=bool)
+        g.create_dataset("alive", (self.mp_number, ), dtype=bool)
         g.create_dataset("phasespace", (self.mp_number, 6), dtype=float)
-        g.create_dataset("current", (1,), dtype=float)
-        
+        g.create_dataset("current", (1, ), dtype=float)
+
         f[group_name]["alive"][:] = self.alive
         f[group_name]["current"][:] = self.current
         for i, dim in enumerate(self):
-            f[group_name]["phasespace"][:,i] = self.particles[dim]
-        
+            f[group_name]["phasespace"][:, i] = self.particles[dim]
+
         f.close()
-        
+
     def load(self, file_name, suffix=None, track_alive=True):
         """
         Load data from a HDF5 file recorded by Bunch save method.
@@ -460,29 +497,30 @@ class Bunch:
             Should be set to True if element such as apertures are used.
             Can be set to False to gain a speed increase.
         """
-        
+
         f = hp.File(file_name, "r")
-        
+
         if suffix is None:
             group_name = "Bunch"
         else:
             group_name = "Bunch_" + str(suffix)
 
         self.mp_number = len(f[group_name]['alive'][:])
-        
+
         for i, dim in enumerate(self):
-            self.particles[dim] = f[group_name]["phasespace"][:,i]
-        
+            self.particles[dim] = f[group_name]["phasespace"][:, i]
+
         self.alive = f[group_name]['alive'][:]
         if f[group_name]['current'][:][0] != 0:
             self.current = f[group_name]['current'][:][0]
         else:
             self.charge_per_mp = 0
-            
+
         self.track_alive = track_alive
-        
+
         f.close()
-        
+
+
 class Beam:
     """
     Define a Beam object composed of several Bunch objects. 
@@ -543,42 +581,41 @@ class Beam:
     load(file_name, mpi)
         Load data from a HDF5 file recorded by Beam save method.
     """
-    
     def __init__(self, ring, bunch_list=None):
         self.ring = ring
         self.mpi_switch = False
         if bunch_list is None:
-            self.init_beam(np.zeros((self.ring.h,1),dtype=bool))
+            self.init_beam(np.zeros((self.ring.h, 1), dtype=bool))
         else:
             if (len(bunch_list) != self.ring.h):
-                raise ValueError(("The length of the bunch list is {} ".format(len(bunch_list)) + 
-                                  "but should be {}".format(self.ring.h)))
+                raise ValueError(("The length of the bunch list is {} ".format(
+                    len(bunch_list)) + "but should be {}".format(self.ring.h)))
             self.bunch_list = bunch_list
-            
+
     def __len__(self):
         """Return the number of (not empty) bunches"""
         length = 0
         for bunch in self.not_empty:
             length += 1
-        return length        
-    
+        return length
+
     def __getitem__(self, i):
         """Return the bunch number i"""
         return self.bunch_list.__getitem__(i)
-    
+
     def __setitem__(self, i, value):
         """Set value to the bunch number i"""
         self.bunch_list.__setitem__(i, value)
-    
+
     def __iter__(self):
         """Iterate over all bunches"""
         return self.bunch_list.__iter__()
-    
+
     def __repr__(self):
         """Return representation of the beam filling pattern"""
         return f'Beam with bunch current:\n {list((self.bunch_current))!r}'
-   
-    @property             
+
+    @property
     def not_empty(self):
         """Return a generator to iterate over not empty bunches."""
         for index, value in enumerate(self.filling_pattern):
@@ -586,19 +623,19 @@ class Beam:
                 yield self[index]
             else:
                 pass
-    
+
     @property
     def distance_between_bunches(self):
         """Return an array which contains the distance to the next bunch in 
         units of the RF period (ring.T1)"""
         return self._distance_between_bunches
-    
+
     def update_distance_between_bunches(self):
         """Update the distance_between_bunches array"""
         filling_pattern = self.filling_pattern
         distance = np.zeros(filling_pattern.shape)
         last_value = 0
-        
+
         # All bunches
         for index, value in enumerate(filling_pattern):
             if value == False:
@@ -606,13 +643,13 @@ class Beam:
             elif value == True:
                 last_value = index
                 count = 1
-                for value2 in filling_pattern[index+1:]:
+                for value2 in filling_pattern[index + 1:]:
                     if value2 == False:
                         count += 1
                     elif value2 == True:
                         break
                 distance[index] = count
-        
+
         # Last bunch case
         count2 = 0
         for index2, value2 in enumerate(filling_pattern):
@@ -621,11 +658,15 @@ class Beam:
             if value2 == False:
                 count2 += 1
         distance[last_value] += count2
-        
-        self._distance_between_bunches =  distance
-        
-    def init_beam(self, filling_pattern, current_per_bunch=1e-3, 
-                  mp_per_bunch=1e3, track_alive=True, mpi=False):
+
+        self._distance_between_bunches = distance
+
+    def init_beam(self,
+                  filling_pattern,
+                  current_per_bunch=1e-3,
+                  mp_per_bunch=1e3,
+                  track_alive=True,
+                  mpi=False):
         """
         Initialize beam with a given filling pattern and marco-particle number 
         per bunch. Then initialize the different bunches with a 6D gaussian
@@ -653,48 +694,51 @@ class Beam:
             If True, only a single bunch is fully initialized on each core, the
             other bunches are initialized with a single marco-particle.
         """
-        
+
         if (len(filling_pattern) != self.ring.h):
-            raise ValueError(("The length of filling pattern is {} ".format(len(filling_pattern)) + 
+            raise ValueError(("The length of filling pattern is {} ".format(
+                len(filling_pattern)) +
                               "but should be {}".format(self.ring.h)))
-        
+
         if mpi is True:
             mp_per_bunch_mpi = mp_per_bunch
             mp_per_bunch = 1
-        
+
         filling_pattern = np.array(filling_pattern)
         bunch_list = []
         if filling_pattern.dtype == np.dtype("bool"):
             for value in filling_pattern:
                 if value == True:
-                    bunch_list.append(Bunch(self.ring, mp_per_bunch, 
-                                            current_per_bunch, track_alive))
+                    bunch_list.append(
+                        Bunch(self.ring, mp_per_bunch, current_per_bunch,
+                              track_alive))
                 elif value == False:
                     bunch_list.append(Bunch(self.ring, alive=False))
         elif filling_pattern.dtype == np.dtype("float64"):
             for current in filling_pattern:
                 if current != 0:
-                    bunch_list.append(Bunch(self.ring, mp_per_bunch, 
-                                            current, track_alive))
+                    bunch_list.append(
+                        Bunch(self.ring, mp_per_bunch, current, track_alive))
                 elif current == 0:
                     bunch_list.append(Bunch(self.ring, alive=False))
         else:
-            raise TypeError("{} should be bool or float64".format(filling_pattern.dtype))
-                
+            raise TypeError("{} should be bool or float64".format(
+                filling_pattern.dtype))
+
         self.bunch_list = bunch_list
         self.update_filling_pattern()
         self.update_distance_between_bunches()
-        
+
         if mpi is True:
             self.mpi_init()
             current = self[self.mpi.rank_to_bunch(self.mpi.rank)].current
-            bunch =  Bunch(self.ring, mp_per_bunch_mpi, current, track_alive)
+            bunch = Bunch(self.ring, mp_per_bunch_mpi, current, track_alive)
             bunch.init_gaussian()
             self[self.mpi.rank_to_bunch(self.mpi.rank)] = bunch
         else:
             for bunch in self.not_empty:
                 bunch.init_gaussian()
-    
+
     def update_filling_pattern(self):
         """Update the beam filling pattern."""
         filling_pattern = []
@@ -704,114 +748,114 @@ class Beam:
             else:
                 filling_pattern.append(False)
         self._filling_pattern = np.array(filling_pattern)
-    
+
     @property
     def filling_pattern(self):
         """Return an array with the filling pattern of the beam as bool"""
         return self._filling_pattern
-    
+
     @property
     def bunch_index(self):
         """Return an array with the positions of the non-empty bunches."""
         return np.where(self.filling_pattern == True)[0]
-        
+
     @property
     def bunch_current(self):
         """Return an array with the current in each bunch in [A]"""
         bunch_current = [bunch.current for bunch in self]
         return np.array(bunch_current)
-    
+
     @property
     def bunch_charge(self):
         """Return an array with the charge in each bunch in [C]"""
         bunch_charge = [bunch.charge for bunch in self]
         return np.array(bunch_charge)
-    
+
     @property
     def bunch_particle(self):
         """Return an array with the particle number in each bunch"""
         bunch_particle = [bunch.particle_number for bunch in self]
         return np.array(bunch_particle)
-    
+
     @property
     def current(self):
         """Total beam current in [A]"""
         return np.sum(self.bunch_current)
-    
+
     @property
     def charge(self):
         """Total beam charge in [C]"""
         return np.sum(self.bunch_charge)
-    
+
     @property
     def particle_number(self):
         """Total number of particles in the beam"""
         return np.sum(self.bunch_particle)
-    
+
     @property
     def bunch_mean(self):
         """Return an array with the mean position of alive particles for each
         bunches"""
-        bunch_mean = np.zeros((6,self.ring.h))
+        bunch_mean = np.zeros((6, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
             index = self.bunch_index[idx]
-            bunch_mean[:,index] = bunch.mean
+            bunch_mean[:, index] = bunch.mean
         return bunch_mean
-    
+
     @property
     def bunch_std(self):
         """Return an array with the standard deviation of the position of alive 
         particles for each bunches"""
-        bunch_std = np.zeros((6,self.ring.h))
+        bunch_std = np.zeros((6, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
             index = self.bunch_index[idx]
-            bunch_std[:,index] = bunch.std
+            bunch_std[:, index] = bunch.std
         return bunch_std
-    
+
     @property
     def bunch_emit(self):
         """Return an array with the bunch emittance of alive particles for each
         bunches and each plane"""
-        bunch_emit = np.zeros((3,self.ring.h))
+        bunch_emit = np.zeros((3, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
             index = self.bunch_index[idx]
-            bunch_emit[:,index] = bunch.emit
+            bunch_emit[:, index] = bunch.emit
         return bunch_emit
-    
+
     @property
     def bunch_cs(self):
         """Return an array with the average Courant-Snyder invariant for each 
         bunch"""
-        bunch_cs = np.zeros((3,self.ring.h))
+        bunch_cs = np.zeros((3, self.ring.h))
         for idx, bunch in enumerate(self.not_empty):
             index = self.bunch_index[idx]
-            bunch_cs[:,index] = bunch.cs_invariant
+            bunch_cs[:, index] = bunch.cs_invariant
         return bunch_cs
-    
+
     def mpi_init(self):
         """Switch on MPI parallelisation and initialise a Mpi object"""
         from mbtrack2.tracking.parallel import Mpi
         self.mpi = Mpi(self.filling_pattern)
         self.mpi_switch = True
-        
+
     def mpi_gather(self):
         """Gather beam, all bunches of the different processors are sent to 
         all processors. Rather slow"""
-        
-        if(self.mpi_switch == False):
+
+        if (self.mpi_switch == False):
             print("Error, mpi is not initialised.")
-        
+
         bunch = self[self.mpi.bunch_num]
         bunches = self.mpi.comm.allgather(bunch)
         for rank in range(self.mpi.size):
             self[self.mpi.rank_to_bunch(rank)] = bunches[rank]
-            
+
     def mpi_close(self):
         """Call mpi_gather and switch off MPI parallelisation"""
         self.mpi_gather()
         self.mpi_switch = False
         self.mpi = None
-        
+
     def plot(self, var, option=None):
         """
         Plot variables with respect to bunch number.
@@ -834,79 +878,94 @@ class Beam:
         fig : Figure
             Figure object with the plot on it.
         """
-        
-        var_dict = {"bunch_current":self.bunch_current,
-                    "bunch_charge":self.bunch_charge,
-                    "bunch_particle":self.bunch_particle,
-                    "bunch_mean":self.bunch_mean,
-                    "bunch_std":self.bunch_std,
-                    "bunch_emit":self.bunch_emit}
-        
-        fig, ax= plt.subplots()
-        
+
+        var_dict = {
+            "bunch_current": self.bunch_current,
+            "bunch_charge": self.bunch_charge,
+            "bunch_particle": self.bunch_particle,
+            "bunch_mean": self.bunch_mean,
+            "bunch_std": self.bunch_std,
+            "bunch_emit": self.bunch_emit
+        }
+
+        fig, ax = plt.subplots()
+
         if var == "bunch_mean" or var == "bunch_std":
-            value_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+            value_dict = {
+                "x": 0,
+                "xp": 1,
+                "y": 2,
+                "yp": 3,
+                "tau": 4,
+                "delta": 5
+            }
             scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-            label_mean = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                      "$\\tau$ (ps)", "$\\delta$"]
-            label_std = ["std x (um)", "std x' ($\\mu$rad)", "std y (um)",
-                        "std y' ($\\mu$rad)", "std $\\tau$ (ps)",
-                        "std $\\delta$"]
-           
+            label_mean = [
+                "x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
+                "$\\tau$ (ps)", "$\\delta$"
+            ]
+            label_std = [
+                "std x (um)", "std x' ($\\mu$rad)", "std y (um)",
+                "std y' ($\\mu$rad)", "std $\\tau$ (ps)", "std $\\delta$"
+            ]
+
             y_axis = var_dict[var][value_dict[option]]
-            
+
             # Convert NaN in y_axis array into zero
             where_is_nan = np.isnan(y_axis)
             y_axis[where_is_nan] = 0
-            
+
             ax.plot(np.arange(len(self.filling_pattern)),
-                      y_axis*scale[value_dict[option]])
+                    y_axis * scale[value_dict[option]])
             ax.set_xlabel('bunch number')
             if var == "bunch_mean":
                 ax.set_ylabel(label_mean[value_dict[option]])
-            else: 
+            else:
                 ax.set_ylabel(label_std[value_dict[option]])
-            
+
         elif var == "bunch_emit":
-            value_dict = {"x":0, "y":1, "s":2}
+            value_dict = {"x": 0, "y": 1, "s": 2}
             scale = [1e9, 1e9, 1e15]
-            
+
             y_axis = var_dict[var][value_dict[option]]
-            
+
             # Convert NaN in y_axis array into zero
             where_is_nan = np.isnan(y_axis)
             y_axis[where_is_nan] = 0
-            
-            ax.plot(np.arange(len(self.filling_pattern)), 
-                     y_axis*scale[value_dict[option]])
-            
+
+            ax.plot(np.arange(len(self.filling_pattern)),
+                    y_axis * scale[value_dict[option]])
+
             if option == "x": label_y = "hor. emittance (nm.rad)"
             elif option == "y": label_y = "ver. emittance (nm.rad)"
-            elif option == "s": label_y =  "long. emittance (fm.rad)"
-            
+            elif option == "s": label_y = "long. emittance (fm.rad)"
+
             ax.set_xlabel('bunch number')
             ax.set_ylabel(label_y)
-                
-        elif var=="bunch_current" or var=="bunch_charge" or var=="bunch_particle":
-            scale = {"bunch_current":1e3, "bunch_charge":1e9, 
-                     "bunch_particle":1}
-            
-            ax.plot(np.arange(len(self.filling_pattern)), var_dict[var]*
-                     scale[var]) 
+
+        elif var == "bunch_current" or var == "bunch_charge" or var == "bunch_particle":
+            scale = {
+                "bunch_current": 1e3,
+                "bunch_charge": 1e9,
+                "bunch_particle": 1
+            }
+
+            ax.plot(np.arange(len(self.filling_pattern)),
+                    var_dict[var] * scale[var])
             ax.set_xlabel('bunch number')
-            
+
             if var == "bunch_current": label_y = "bunch current (mA)"
             elif var == "bunch_charge": label_y = "bunch chagre (nC)"
             else: label_y = "number of particles"
 
-            ax.set_ylabel(label_y)             
-    
-        elif var == "current" or var=="charge" or var=="particle_number":
-            raise ValueError("'{0}'is a total value and cannot be plotted."
-                             .format(var))
-       
+            ax.set_ylabel(label_y)
+
+        elif var == "current" or var == "charge" or var == "particle_number":
+            raise ValueError(
+                "'{0}'is a total value and cannot be plotted.".format(var))
+
         return fig
-        
+
     def save(self, file_name):
         """
         Save beam object data in an HDF5 file format.
@@ -920,32 +979,35 @@ class Beam:
         """
         if self.mpi_switch is True:
             for i, bunch in enumerate(self):
-                if i in self.mpi.table[:,1]:
+                if i in self.mpi.table[:, 1]:
                     if i == self.mpi.bunch_num:
                         mp_number = self[self.mpi.bunch_num].mp_number
                         self.mpi.comm.bcast(mp_number, root=self.mpi.rank)
-                        self[self.mpi.bunch_num].save(file_name, 
-                                                      self.mpi.bunch_num, 
+                        self[self.mpi.bunch_num].save(file_name,
+                                                      self.mpi.bunch_num,
                                                       self.mpi.comm)
                     else:
                         mp_number = None
-                        mp_number = self.mpi.comm.bcast(mp_number, root=self.mpi.bunch_to_rank(i))
-                        f = hp.File(file_name+ ".hdf5", "a", libver='earliest', 
-                                    driver='mpio', comm=self.mpi.comm)
+                        mp_number = self.mpi.comm.bcast(
+                            mp_number, root=self.mpi.bunch_to_rank(i))
+                        f = hp.File(file_name + ".hdf5",
+                                    "a",
+                                    libver='earliest',
+                                    driver='mpio',
+                                    comm=self.mpi.comm)
                         group_name = "Bunch_" + str(i)
                         g = f.create_group(group_name)
-                        g.create_dataset("alive", (mp_number,), dtype=bool)
-                        g.create_dataset("phasespace", (mp_number, 6), dtype=float)
-                        g.create_dataset("current", (1,), dtype=float)  
+                        g.create_dataset("alive", (mp_number, ), dtype=bool)
+                        g.create_dataset("phasespace", (mp_number, 6),
+                                         dtype=float)
+                        g.create_dataset("current", (1, ), dtype=float)
                         f.close()
                 else:
-                    bunch.save(file_name, 
-                               i,
-                               self.mpi.comm)
+                    bunch.save(file_name, i, self.mpi.comm)
         else:
             for i, bunch in enumerate(self):
                 bunch.save(file_name, i)
-    
+
     def load(self, file_name, mpi, track_alive=True):
         """
         Load data from a HDF5 file recorded by Beam save method.
@@ -970,7 +1032,7 @@ class Beam:
             for i in range(self.ring.h):
                 current = f["Bunch_" + str(i)]['current'][:][0]
                 filling_pattern.append(current)
-                    
+
             self.init_beam(filling_pattern,
                            mp_per_bunch=1,
                            track_alive=track_alive,
@@ -980,4 +1042,4 @@ class Beam:
             for i, bunch in enumerate(self):
                 bunch.load(file_name, i, track_alive)
             self.update_filling_pattern()
-            self.update_distance_between_bunches()
\ No newline at end of file
+            self.update_distance_between_bunches()
diff --git a/mbtrack2/tracking/rf.py b/mbtrack2/tracking/rf.py
index 7f38c0bb0d17c35bb18ffa649a77d3f6f0be2774..fedc7ab0e548c91f444b216be95b3cb1d7149ae7 100644
--- a/mbtrack2/tracking/rf.py
+++ b/mbtrack2/tracking/rf.py
@@ -3,12 +3,14 @@
 This module handles radio-frequency (RF) cavitiy elements. 
 """
 
-import numpy as np
-import matplotlib.pyplot as plt
 import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
 from matplotlib.legend_handler import HandlerPatch
+
 from mbtrack2.tracking.element import Element
 
+
 class RFCavity(Element):
     """
     Perfect RF cavity class for main and harmonic RF cavities.
@@ -26,12 +28,12 @@ class RFCavity(Element):
     """
     def __init__(self, ring, m, Vc, theta):
         self.ring = ring
-        self.m = m 
+        self.m = m
         self.Vc = Vc
         self.theta = theta
-        
-    @Element.parallel    
-    def track(self,bunch):
+
+    @Element.parallel
+    def track(self, bunch):
         """
         Tracking method for the element.
         No bunch to bunch interaction, so written for Bunch objects and
@@ -42,13 +44,13 @@ class RFCavity(Element):
         bunch : Bunch or Beam object
         """
         bunch["delta"] += self.Vc / self.ring.E0 * np.cos(
-                self.m * self.ring.omega1 * bunch["tau"] + self.theta )
-        
+            self.m * self.ring.omega1 * bunch["tau"] + self.theta)
+
     def value(self, val):
-        return self.Vc / self.ring.E0 * np.cos( 
-                self.m * self.ring.omega1 * val + self.theta )
-    
-    
+        return self.Vc / self.ring.E0 * np.cos(self.m * self.ring.omega1 *
+                                               val + self.theta)
+
+
 class CavityResonator():
     """Cavity resonator class for active or passive RF cavity with beam
     loading or HOM, based on [1,2].
@@ -198,7 +200,16 @@ class CavityResonator():
     of Longitudinal Beam Dynamics With Harmonic Cavities by Using the Code 
     Mbtrack." IPAC’19, Melbourne, Australia, 2019.
     """
-    def __init__(self, ring, m, Rs, Q, QL, detune, Ncav=1, Vc=0, theta=0, 
+    def __init__(self,
+                 ring,
+                 m,
+                 Rs,
+                 Q,
+                 QL,
+                 detune,
+                 Ncav=1,
+                 Vc=0,
+                 theta=0,
                  n_bin=75):
         self.ring = ring
         self.feedback = []
@@ -223,7 +234,7 @@ class CavityResonator():
         self.theta_gr = 0
         self.Pg = 0
         self.n_bin = int(n_bin)
-        
+
     def init_tracking(self, beam):
         """
         Initialization of the tracking.
@@ -234,13 +245,13 @@ class CavityResonator():
 
         """
         if beam.mpi_switch:
-            self.bunch_index = beam.mpi.bunch_num # Number of the tracked bunch in this processor
-            
+            self.bunch_index = beam.mpi.bunch_num  # Number of the tracked bunch in this processor
+
         self.distance = beam.distance_between_bunches
         self.valid_bunch_index = beam.bunch_index
         self.tracking = True
         self.nturn = 0
-    
+
     def track(self, beam):
         """
         Track a Beam object through the CavityResonator object.
@@ -256,29 +267,30 @@ class CavityResonator():
         beam : Beam object
 
         """
-        
+
         if self.tracking is False:
             self.init_tracking(beam)
-        
+
         for index, bunch in enumerate(beam):
-            
+
             if beam.filling_pattern[index]:
-                
+
                 if beam.mpi_switch:
                     # get rank of bunch n° index
                     rank = beam.mpi.bunch_to_rank(index)
                     # mpi -> get shared bunch profile for current bunch
                     center = beam.mpi.tau_center[rank]
                     profile = beam.mpi.tau_profile[rank]
-                    bin_length = center[1]-center[0]
+                    bin_length = center[1] - center[0]
                     charge_per_mp = beam.mpi.charge_per_mp_all[rank]
                     if index == self.bunch_index:
                         sorted_index = beam.mpi.tau_sorted_index
                 else:
                     # no mpi -> get bunch profile for current bunch
                     if len(bunch) != 0:
-                        (bins, sorted_index, profile, center) = bunch.binning(n_bin=self.n_bin)
-                        bin_length = center[1]-center[0]
+                        (bins, sorted_index, profile,
+                         center) = bunch.binning(n_bin=self.n_bin)
+                        bin_length = center[1] - center[0]
                         charge_per_mp = bunch.charge_per_mp
                         self.bunch_index = index
                     else:
@@ -290,52 +302,59 @@ class CavityResonator():
                         # phasor decay to be at t=0 of the next bunch
                         self.phasor_decay(self.ring.T1, ref_frame="beam")
                         continue
-                
-                energy_change = bunch["tau"]*0
-                
+
+                energy_change = bunch["tau"] * 0
+
                 # remove part of beam phasor decay to be at the start of the binning (=bins[0])
                 self.phasor_decay(center[0] - bin_length/2, ref_frame="beam")
-                
+
                 if index != self.bunch_index:
-                    self.phasor_evol(profile, bin_length, charge_per_mp, ref_frame="beam")
+                    self.phasor_evol(profile,
+                                     bin_length,
+                                     charge_per_mp,
+                                     ref_frame="beam")
                 else:
                     # modify beam phasor
                     for i, center0 in enumerate(center):
                         mp_per_bin = profile[i]
-                        
+
                         if mp_per_bin == 0:
                             self.phasor_decay(bin_length, ref_frame="beam")
                             continue
-                        
+
                         ind = (sorted_index == i)
-                        phase = self.m * self.ring.omega1 * (center0 + self.ring.T1* (index + self.ring.h * self.nturn))
-                        Vgene = np.real(self.generator_phasor_record[index]*np.exp(1j*phase))
+                        phase = self.m * self.ring.omega1 * (
+                            center0 + self.ring.T1 *
+                            (index + self.ring.h * self.nturn))
+                        Vgene = np.real(self.generator_phasor_record[index] *
+                                        np.exp(1j * phase))
                         Vbeam = np.real(self.beam_phasor)
-                        Vtot = Vgene + Vbeam - charge_per_mp*self.loss_factor*mp_per_bin
+                        Vtot = Vgene + Vbeam - charge_per_mp * self.loss_factor * mp_per_bin
                         energy_change[ind] = Vtot / self.ring.E0
-    
-                        self.beam_phasor -= 2*charge_per_mp*self.loss_factor*mp_per_bin
+
+                        self.beam_phasor -= 2 * charge_per_mp * self.loss_factor * mp_per_bin
                         self.phasor_decay(bin_length, ref_frame="beam")
-                
+
                 # phasor decay to be at t=0 of the current bunch (=-1*bins[-1])
-                self.phasor_decay(-1 * (center[-1] + bin_length/2), ref_frame="beam")
-                
+                self.phasor_decay(-1 * (center[-1] + bin_length/2),
+                                  ref_frame="beam")
+
                 if index == self.bunch_index:
                     # apply kick
                     bunch["delta"] += energy_change
-            
+
             # save beam phasor value
             self.beam_phasor_record[index] = self.beam_phasor
-            
+
             # phasor decay to be at t=0 of the next bunch
             self.phasor_decay(self.ring.T1, ref_frame="beam")
-            
+
         # apply different kind of RF feedback
         for fb in self.feedback:
             fb.track()
-                
+
         self.nturn += 1
-                
+
     def init_phasor_track(self, beam):
         """
         Initialize the beam phasor for a given beam distribution using a
@@ -348,44 +367,56 @@ class CavityResonator():
         ----------
         beam : Beam object
 
-        """        
+        """
         if self.tracking is False:
             self.init_tracking(beam)
-            
-        n_turn = int(self.filling_time/self.ring.T0*10)
-        
+
+        n_turn = int(self.filling_time / self.ring.T0 * 10)
+
         for i in range(n_turn):
             for j, bunch in enumerate(beam.not_empty):
-                
+
                 index = self.valid_bunch_index[j]
-                
+
                 if beam.mpi_switch:
                     # get shared bunch profile for current bunch
                     center = beam.mpi.tau_center[j]
                     profile = beam.mpi.tau_profile[j]
-                    bin_length = center[1]-center[0]
+                    bin_length = center[1] - center[0]
                     charge_per_mp = beam.mpi.charge_per_mp_all[j]
                 else:
                     if i == 0:
                         # get bunch profile for current bunch
-                        (bins, sorted_index, profile, center) = bunch.binning(n_bin=self.n_bin)
+                        (bins, sorted_index, profile,
+                         center) = bunch.binning(n_bin=self.n_bin)
                         if j == 0:
-                            self.profile_save = np.zeros((len(beam),len(profile),))
-                            self.center_save = np.zeros((len(beam),len(center),))
-                        self.profile_save[j,:] = profile
-                        self.center_save[j,:] = center
+                            self.profile_save = np.zeros((
+                                len(beam),
+                                len(profile),
+                            ))
+                            self.center_save = np.zeros((
+                                len(beam),
+                                len(center),
+                            ))
+                        self.profile_save[j, :] = profile
+                        self.center_save[j, :] = center
                     else:
-                        profile = self.profile_save[j,:]
-                        center = self.center_save[j,:]
-                        
-                    bin_length = center[1]-center[0]
+                        profile = self.profile_save[j, :]
+                        center = self.center_save[j, :]
+
+                    bin_length = center[1] - center[0]
                     charge_per_mp = bunch.charge_per_mp
-                
+
                 self.phasor_decay(center[0] - bin_length/2, ref_frame="rf")
-                self.phasor_evol(profile, bin_length, charge_per_mp, ref_frame="rf")
-                self.phasor_decay(-1 * (center[-1] + bin_length/2), ref_frame="rf")
-                self.phasor_decay( (self.distance[index] * self.ring.T1), ref_frame="rf")
-            
+                self.phasor_evol(profile,
+                                 bin_length,
+                                 charge_per_mp,
+                                 ref_frame="rf")
+                self.phasor_decay(-1 * (center[-1] + bin_length/2),
+                                  ref_frame="rf")
+                self.phasor_decay((self.distance[index] * self.ring.T1),
+                                  ref_frame="rf")
+
     def phasor_decay(self, time, ref_frame="beam"):
         """
         Compute the beam phasor decay during a given time span, assuming that 
@@ -402,11 +433,15 @@ class CavityResonator():
         if ref_frame == "beam":
             delta = self.wr
         elif ref_frame == "rf":
-            delta = (self.wr - self.m*self.ring.omega1)
-        self.beam_phasor = self.beam_phasor * np.exp((-1/self.filling_time +
-                                  1j*delta)*time)
-        
-    def phasor_evol(self, profile, bin_length, charge_per_mp, ref_frame="beam"):
+            delta = (self.wr - self.m * self.ring.omega1)
+        self.beam_phasor = self.beam_phasor * np.exp(
+            (-1 / self.filling_time + 1j*delta) * time)
+
+    def phasor_evol(self,
+                    profile,
+                    bin_length,
+                    charge_per_mp,
+                    ref_frame="beam"):
         """
         Compute the beam phasor evolution during the crossing of a bunch using 
         an analytic formula [1].
@@ -432,22 +467,22 @@ class CavityResonator():
         if ref_frame == "beam":
             delta = self.wr
         elif ref_frame == "rf":
-            delta = (self.wr - self.m*self.ring.omega1)
-            
+            delta = (self.wr - self.m * self.ring.omega1)
+
         n_bin = len(profile)
-        
+
         # Phasor decay during crossing time
-        deltaT = n_bin*bin_length
+        deltaT = n_bin * bin_length
         self.phasor_decay(deltaT, ref_frame)
-        
+
         # Phasor evolution due to induced voltage by marco-particles
         k = np.arange(0, n_bin)
-        var = np.exp( (-1/self.filling_time + 1j*delta) * 
-                      (n_bin-k) * bin_length )
+        var = np.exp(
+            (-1 / self.filling_time + 1j*delta) * (n_bin-k) * bin_length)
         sum_tot = np.sum(profile * var)
         sum_val = -2 * sum_tot * charge_per_mp * self.loss_factor
         self.beam_phasor += sum_val
-        
+
     def init_phasor(self, beam):
         """
         Initialize the beam phasor for a given beam distribution using an
@@ -464,49 +499,51 @@ class CavityResonator():
         [1] mbtrack2 manual.
 
         """
-        
+
         # Initialization
         if self.tracking is False:
             self.init_tracking(beam)
-        
+
         N = self.n_bin - 1
-        delta = (self.wr - self.m*self.ring.omega1)
-        n_turn = int(self.filling_time/self.ring.T0*10)
-        
-        T = np.ones(self.ring.h)*self.ring.T1
+        delta = (self.wr - self.m * self.ring.omega1)
+        n_turn = int(self.filling_time / self.ring.T0 * 10)
+
+        T = np.ones(self.ring.h) * self.ring.T1
         bin_length = np.zeros(self.ring.h)
         charge_per_mp = np.zeros(self.ring.h)
         profile = np.zeros((N, self.ring.h))
         center = np.zeros((N, self.ring.h))
-        
+
         # Gather beam distribution data
         for j, bunch in enumerate(beam.not_empty):
             index = self.valid_bunch_index[j]
             if beam.mpi_switch:
                 beam.mpi.share_distributions(beam, n_bin=self.n_bin)
-                center[:,index] = beam.mpi.tau_center[j]
-                profile[:,index] = beam.mpi.tau_profile[j]
-                bin_length[index] = center[1, index]-center[0, index]
+                center[:, index] = beam.mpi.tau_center[j]
+                profile[:, index] = beam.mpi.tau_profile[j]
+                bin_length[index] = center[1, index] - center[0, index]
                 charge_per_mp[index] = beam.mpi.charge_per_mp_all[j]
             else:
-                (bins, sorted_index, profile[:, index], center[:, index]) = bunch.binning(n_bin=self.n_bin)
-                bin_length[index] = center[1, index]-center[0, index]
+                (bins, sorted_index, profile[:, index],
+                 center[:, index]) = bunch.binning(n_bin=self.n_bin)
+                bin_length[index] = center[1, index] - center[0, index]
                 charge_per_mp[index] = bunch.charge_per_mp
-            T[index] -= (center[-1, index] + bin_length[index]/2)
+            T[index] -= (center[-1, index] + bin_length[index] / 2)
             if index != 0:
-                T[index - 1] += (center[0, index] - bin_length[index]/2)
-        T[self.ring.h - 1] += (center[0, 0] - bin_length[0]/2)
+                T[index - 1] += (center[0, index] - bin_length[index] / 2)
+        T[self.ring.h - 1] += (center[0, 0] - bin_length[0] / 2)
 
         # Compute matrix coefficients
         k = np.arange(0, N)
         Tkj = np.zeros((N, self.ring.h))
         for j in range(self.ring.h):
-            sum_t = np.array([T[n] + N*bin_length[n] for n in range(j+1,self.ring.h)])
-            Tkj[:,j] = (N-k)*bin_length[j] + T[j] + np.sum(sum_t)
-            
-        var = np.exp( (-1/self.filling_time + 1j*delta) * Tkj )
+            sum_t = np.array(
+                [T[n] + N * bin_length[n] for n in range(j + 1, self.ring.h)])
+            Tkj[:, j] = (N-k) * bin_length[j] + T[j] + np.sum(sum_t)
+
+        var = np.exp((-1 / self.filling_time + 1j*delta) * Tkj)
         sum_tot = np.sum((profile*charge_per_mp) * var)
-        
+
         # Use the formula n_turn times
         for i in range(n_turn):
             # Phasor decay during one turn
@@ -514,26 +551,27 @@ class CavityResonator():
             # Phasor evolution due to induced voltage by marco-particles during one turn
             sum_val = -2 * sum_tot * self.loss_factor
             self.beam_phasor += sum_val
-        
+
         # Replace phasor at t=0 (synchronous particle) of the first non empty bunch.
         idx0 = self.valid_bunch_index[0]
-        self.phasor_decay(center[-1,idx0] + bin_length[idx0]/2, ref_frame="rf")
-    
+        self.phasor_decay(center[-1, idx0] + bin_length[idx0] / 2,
+                          ref_frame="rf")
+
     @property
     def generator_phasor(self):
         """Generator phasor in [V]"""
-        return self.Vg*np.exp(1j*self.theta_g)
-    
+        return self.Vg * np.exp(1j * self.theta_g)
+
     @property
     def cavity_phasor(self):
         """Cavity total phasor in [V]"""
         return self.generator_phasor + self.beam_phasor
-    
+
     @property
     def cavity_phasor_record(self):
         """Last cavity phasor value of each bunch in [V]"""
         return self.generator_phasor_record + self.beam_phasor_record
-    
+
     @property
     def ig_phasor_record(self):
         """Last current generator phasor of each bunch in [A]"""
@@ -541,31 +579,31 @@ class CavityResonator():
             if isinstance(FB, (ProportionalIntegralLoop, DirectFeedback)):
                 return FB.ig_phasor_record
         return np.zeros(self.ring.h)
-    
+
     @property
     def cavity_voltage(self):
         """Cavity total voltage in [V]"""
         return np.abs(self.cavity_phasor)
-    
+
     @property
     def cavity_phase(self):
         """Cavity total phase in [rad]"""
         return np.angle(self.cavity_phasor)
-    
+
     @property
     def beam_voltage(self):
         """Beam loading voltage in [V]"""
         return np.abs(self.beam_phasor)
-    
+
     @property
     def beam_phase(self):
         """Beam loading phase in [rad]"""
         return np.angle(self.beam_phasor)
-    
+
     @property
     def loss_factor(self):
         """Cavity loss factor in [V/C]"""
-        return self.wr*self.Rs/(2 * self.Q)
+        return self.wr * self.Rs / (2 * self.Q)
 
     @property
     def m(self):
@@ -575,7 +613,7 @@ class CavityResonator():
     @m.setter
     def m(self, value):
         self._m = value
-        
+
     @property
     def Ncav(self):
         """Number of cavities"""
@@ -584,7 +622,7 @@ class CavityResonator():
     @Ncav.setter
     def Ncav(self, value):
         self._Ncav = value
-        
+
     @property
     def Rs_per_cavity(self):
         """Shunt impedance of a single cavity in [Ohm], defined as 
@@ -603,7 +641,7 @@ class CavityResonator():
     @Rs.setter
     def Rs(self, value):
         self.Rs_per_cavity = value / self.Ncav
-        
+
     @property
     def RL(self):
         """Loaded shunt impedance [ohm]"""
@@ -626,7 +664,7 @@ class CavityResonator():
     @QL.setter
     def QL(self, value):
         self._QL = value
-        self._beta = self.Q/self.QL - 1
+        self._beta = self.Q / self.QL - 1
         self.update_feedback()
 
     @property
@@ -636,7 +674,7 @@ class CavityResonator():
 
     @beta.setter
     def beta(self, value):
-        self.QL = self.Q/(1 + value)
+        self.QL = self.Q / (1+value)
 
     @property
     def detune(self):
@@ -646,10 +684,10 @@ class CavityResonator():
     @detune.setter
     def detune(self, value):
         self._detune = value
-        self._fr = self.detune + self.m*self.ring.f1
-        self._wr = self.fr*2*np.pi
-        self._psi = np.arctan(self.QL*(self.fr/(self.m*self.ring.f1) -
-                                       (self.m*self.ring.f1)/self.fr))
+        self._fr = self.detune + self.m * self.ring.f1
+        self._wr = self.fr * 2 * np.pi
+        self._psi = np.arctan(self.QL * (self.fr / (self.m * self.ring.f1) -
+                                         (self.m * self.ring.f1) / self.fr))
         self.update_feedback()
 
     @property
@@ -659,7 +697,7 @@ class CavityResonator():
 
     @fr.setter
     def fr(self, value):
-        self.detune = value - self.m*self.ring.f1
+        self.detune = value - self.m * self.ring.f1
 
     @property
     def wr(self):
@@ -668,7 +706,7 @@ class CavityResonator():
 
     @wr.setter
     def wr(self, value):
-        self.detune = (value - self.m*self.ring.f1)*2*np.pi
+        self.detune = (value - self.m * self.ring.f1) * 2 * np.pi
 
     @property
     def psi(self):
@@ -677,20 +715,22 @@ class CavityResonator():
 
     @psi.setter
     def psi(self, value):
-        delta = (self.ring.f1*self.m*np.tan(value)/self.QL)**2 + 4*(self.ring.f1*self.m)**2
-        fr = (self.ring.f1*self.m*np.tan(value)/self.QL + np.sqrt(delta))/2
-        self.detune = fr - self.m*self.ring.f1
-        
+        delta = (self.ring.f1 * self.m * np.tan(value) /
+                 self.QL)**2 + 4 * (self.ring.f1 * self.m)**2
+        fr = (self.ring.f1 * self.m * np.tan(value) / self.QL +
+              np.sqrt(delta)) / 2
+        self.detune = fr - self.m * self.ring.f1
+
     @property
     def filling_time(self):
         """Cavity filling time in [s]"""
-        return 2*self.QL/self.wr
-    
+        return 2 * self.QL / self.wr
+
     @property
     def Pc(self):
         """Power dissipated in the cavity walls in [W]"""
         return self.Vc**2 / (2 * self.Rs)
-    
+
     def Pb(self, I0):
         """
         Return power transmitted to the beam in [W] - near Eq. (4.2.3) in [1].
@@ -707,7 +747,7 @@ class CavityResonator():
 
         """
         return I0 * self.Vc * np.cos(self.theta)
-    
+
     def Pr(self, I0):
         """
         Power reflected back to the generator in [W].
@@ -740,8 +780,8 @@ class CavityResonator():
             Beam voltage at resonance in [V].
 
         """
-        return 2*I0*self.Rs/(1+self.beta)
-    
+        return 2 * I0 * self.Rs / (1 + self.beta)
+
     def Vb(self, I0):
         """
         Return beam voltage in [V].
@@ -757,12 +797,12 @@ class CavityResonator():
             Beam voltage in [V].
 
         """
-        return self.Vbr(I0)*np.cos(self.psi)
-    
+        return self.Vbr(I0) * np.cos(self.psi)
+
     def Z(self, f):
         """Cavity impedance in [Ohm] for a given frequency f in [Hz]"""
-        return self.RL/(1 + 1j*self.QL*(self.fr/f - f/self.fr))
-    
+        return self.RL / (1 + 1j * self.QL * (self.fr / f - f / self.fr))
+
     def set_optimal_detune(self, I0):
         """
         Set detuning to optimal conditions - second Eq. (4.2.1) in [1].
@@ -773,8 +813,8 @@ class CavityResonator():
             Beam current in [A].
 
         """
-        self.psi = np.arctan(-self.Vbr(I0)/self.Vc*np.sin(self.theta))
-        
+        self.psi = np.arctan(-self.Vbr(I0) / self.Vc * np.sin(self.theta))
+
     def set_optimal_coupling(self, I0):
         """
         Set coupling to optimal value - Eq. (4.2.3) in [1].
@@ -785,9 +825,8 @@ class CavityResonator():
             Beam current in [A].
 
         """
-        self.beta = 1 + (2 * I0 * self.Rs * np.cos(self.theta) / 
-                         self.Vc)
-                
+        self.beta = 1 + (2 * I0 * self.Rs * np.cos(self.theta) / self.Vc)
+
     def set_generator(self, I0):
         """
         Set generator parameters (Pg, Vgr, theta_gr, Vg and theta_g) for a 
@@ -799,23 +838,32 @@ class CavityResonator():
             Beam current in [A].
 
         """
-        
+
         # Generator power [W] - Eq. (4.1.2) [1] corrected with factor (1+beta)**2 instead of (1+beta**2)
-        self.Pg = self.Vc**2*(1+self.beta)**2/(2*self.Rs*4*self.beta*np.cos(self.psi)**2)*(
-            (np.cos(self.theta) + 2*I0*self.Rs/(self.Vc*(1+self.beta))*np.cos(self.psi)**2 )**2 + 
-            (np.sin(self.theta) + 2*I0*self.Rs/(self.Vc*(1+self.beta))*np.cos(self.psi)*np.sin(self.psi) )**2)
+        self.Pg = self.Vc**2 * (1 + self.beta)**2 / (
+            2 * self.Rs * 4 * self.beta * np.cos(self.psi)**2) * (
+                (np.cos(self.theta) + 2 * I0 * self.Rs /
+                 (self.Vc * (1 + self.beta)) * np.cos(self.psi)**2)**2 +
+                (np.sin(self.theta) + 2 * I0 * self.Rs /
+                 (self.Vc *
+                  (1 + self.beta)) * np.cos(self.psi) * np.sin(self.psi))**2)
         # Generator voltage at resonance [V] - Eq. (3.2.2) [1]
-        self.Vgr = 2*self.beta**(1/2)/(1+self.beta)*(2*self.Rs*self.Pg)**(1/2)
+        self.Vgr = 2 * self.beta**(1 / 2) / (1 + self.beta) * (
+            2 * self.Rs * self.Pg)**(1 / 2)
         # Generator phase at resonance [rad] - from Eq. (4.1.1)
-        self.theta_gr = np.arctan((self.Vc*np.sin(self.theta) + self.Vbr(I0)*np.cos(self.psi)*np.sin(self.psi))/
-                    (self.Vc*np.cos(self.theta) + self.Vbr(I0)*np.cos(self.psi)**2)) - self.psi
+        self.theta_gr = np.arctan(
+            (self.Vc * np.sin(self.theta) +
+             self.Vbr(I0) * np.cos(self.psi) * np.sin(self.psi)) /
+            (self.Vc * np.cos(self.theta) +
+             self.Vbr(I0) * np.cos(self.psi)**2)) - self.psi
         # Generator voltage [V]
-        self.Vg = self.Vgr*np.cos(self.psi)
+        self.Vg = self.Vgr * np.cos(self.psi)
         # Generator phase [rad]
         self.theta_g = self.theta_gr + self.psi
         # Set generator_phasor_record
-        self.generator_phasor_record = np.ones(self.ring.h)*self.generator_phasor
-        
+        self.generator_phasor_record = np.ones(
+            self.ring.h) * self.generator_phasor
+
     def plot_phasor(self, I0):
         """
         Plot phasor diagram showing the vector addition of generator and beam 
@@ -831,30 +879,57 @@ class CavityResonator():
         Figure.
 
         """
-
-        def make_legend_arrow(legend, orig_handle,
-                              xdescent, ydescent,
-                              width, height, fontsize):
-            p = mpatches.FancyArrow(0, 0.5*height, width, 0, length_includes_head=True, head_width=0.75*height )
+        def make_legend_arrow(legend, orig_handle, xdescent, ydescent, width,
+                              height, fontsize):
+            p = mpatches.FancyArrow(0,
+                                    0.5 * height,
+                                    width,
+                                    0,
+                                    length_includes_head=True,
+                                    head_width=0.75 * height)
             return p
 
         fig = plt.figure()
-        ax= fig.add_subplot(111, polar=True)
-        ax.set_rmax(max([1.2,self.Vb(I0)/self.Vc*1.2,self.Vg/self.Vc*1.2]))
-        arr1 = ax.arrow(self.theta, 0, 0, 1, alpha = 0.5, width = 0.015,
-                         edgecolor = 'black', lw = 2)
-
-        arr2 = ax.arrow(self.psi + np.pi, 0, 0,self.Vb(I0)/self.Vc, alpha = 0.5, width = 0.015,
-                         edgecolor = 'red', lw = 2)
-
-        arr3 = ax.arrow(self.theta_g, 0, 0,self.Vg/self.Vc, alpha = 0.5, width = 0.015,
-                         edgecolor = 'blue', lw = 2)
+        ax = fig.add_subplot(111, polar=True)
+        ax.set_rmax(
+            max([1.2,
+                 self.Vb(I0) / self.Vc * 1.2, self.Vg / self.Vc * 1.2]))
+        arr1 = ax.arrow(self.theta,
+                        0,
+                        0,
+                        1,
+                        alpha=0.5,
+                        width=0.015,
+                        edgecolor='black',
+                        lw=2)
+
+        arr2 = ax.arrow(self.psi + np.pi,
+                        0,
+                        0,
+                        self.Vb(I0) / self.Vc,
+                        alpha=0.5,
+                        width=0.015,
+                        edgecolor='red',
+                        lw=2)
+
+        arr3 = ax.arrow(self.theta_g,
+                        0,
+                        0,
+                        self.Vg / self.Vc,
+                        alpha=0.5,
+                        width=0.015,
+                        edgecolor='blue',
+                        lw=2)
 
         ax.set_rticks([])  # less radial ticks
-        plt.legend([arr1,arr2,arr3], ['Vc','Vb','Vg'],handler_map={mpatches.FancyArrow : HandlerPatch(patch_func=make_legend_arrow),})
-        
+        plt.legend([arr1, arr2, arr3], ['Vc', 'Vb', 'Vg'],
+                   handler_map={
+                       mpatches.FancyArrow:
+                       HandlerPatch(patch_func=make_legend_arrow),
+                   })
+
         return fig
-        
+
     def is_DC_Robinson_stable(self, I0):
         """
         Check DC Robinson stability - Eq. (6.1.1) [1]
@@ -869,9 +944,10 @@ class CavityResonator():
         bool
 
         """
-        return 2*self.Vc*np.sin(self.theta) + self.Vbr(I0)*np.sin(2*self.psi) > 0
-    
-    def plot_DC_Robinson_stability(self, detune_range = [-1e5,1e5]):
+        return 2 * self.Vc * np.sin(self.theta) + self.Vbr(I0) * np.sin(
+            2 * self.psi) > 0
+
+    def plot_DC_Robinson_stability(self, detune_range=[-1e5, 1e5]):
         """
         Plot DC Robinson stability limit.
 
@@ -886,46 +962,59 @@ class CavityResonator():
 
         """
         old_detune = self.psi
-        
-        x = np.linspace(detune_range[0],detune_range[1],1000)
+
+        x = np.linspace(detune_range[0], detune_range[1], 1000)
         y = []
-        for i in range(0,x.size):
+        for i in range(0, x.size):
             self.detune = x[i]
-            y.append(-self.Vc*(1+self.beta)/(self.Rs*np.sin(2*self.psi))*np.sin(self.theta)) # droite de stabilité
-            
+            y.append(-self.Vc * (1 + self.beta) /
+                     (self.Rs * np.sin(2 * self.psi)) *
+                     np.sin(self.theta))  # droite de stabilité
+
         fig = plt.figure()
         ax = plt.gca()
-        ax.plot(x,y)
+        ax.plot(x, y)
         ax.set_xlabel("Detune [Hz]")
         ax.set_ylabel("Threshold current [A]")
         ax.set_title("DC Robinson stability limit")
-        
+
         self.psi = old_detune
-        
+
         return fig
-        
-    def VRF(self, z, I0, F = 1, PHI = 0):
+
+    def VRF(self, z, I0, F=1, PHI=0):
         """Total RF voltage taking into account form factor amplitude F and form factor phase PHI"""
-        return self.Vg*np.cos(self.ring.k1*self.m*z + self.theta_g) - self.Vb(I0)*F*np.cos(self.ring.k1*self.m*z + self.psi - PHI)
-    
-    def dVRF(self, z, I0, F = 1, PHI = 0):
+        return self.Vg * np.cos(self.ring.k1 * self.m * z +
+                                self.theta_g) - self.Vb(I0) * F * np.cos(
+                                    self.ring.k1 * self.m * z + self.psi - PHI)
+
+    def dVRF(self, z, I0, F=1, PHI=0):
         """Return derivative of total RF voltage taking into account form factor amplitude F and form factor phase PHI"""
-        return -1*self.Vg*self.ring.k1*self.m*np.sin(self.ring.k1*self.m*z + self.theta_g) + self.Vb(I0)*F*self.ring.k1*self.m*np.sin(self.ring.k1*self.m*z + self.psi - PHI)
-    
-    def ddVRF(self, z, I0, F = 1, PHI = 0):
+        return -1 * self.Vg * self.ring.k1 * self.m * np.sin(
+            self.ring.k1 * self.m * z +
+            self.theta_g) + self.Vb(I0) * F * self.ring.k1 * self.m * np.sin(
+                self.ring.k1 * self.m * z + self.psi - PHI)
+
+    def ddVRF(self, z, I0, F=1, PHI=0):
         """Return the second derivative of total RF voltage taking into account form factor amplitude F and form factor phase PHI"""
-        return -1*self.Vg*(self.ring.k1*self.m)**2*np.cos(self.ring.k1*self.m*z + self.theta_g) + self.Vb(I0)*F*(self.ring.k1*self.m)**2*np.cos(self.ring.k1*self.m*z + self.psi - PHI)
-        
-    def deltaVRF(self, z, I0, F = 1, PHI = 0):
+        return -1 * self.Vg * (self.ring.k1 * self.m)**2 * np.cos(
+            self.ring.k1 * self.m * z + self.theta_g) + self.Vb(I0) * F * (
+                self.ring.k1 * self.m)**2 * np.cos(self.ring.k1 * self.m * z +
+                                                   self.psi - PHI)
+
+    def deltaVRF(self, z, I0, F=1, PHI=0):
         """Return the generator voltage minus beam loading voltage taking into account form factor amplitude F and form factor phase PHI"""
-        return -1*self.Vg*(self.ring.k1*self.m)**2*np.cos(self.ring.k1*self.m*z + self.theta_g) - self.Vb(I0)*F*(self.ring.k1*self.m)**2*np.cos(self.ring.k1*self.m*z + self.psi - PHI)
+        return -1 * self.Vg * (self.ring.k1 * self.m)**2 * np.cos(
+            self.ring.k1 * self.m * z + self.theta_g) - self.Vb(I0) * F * (
+                self.ring.k1 * self.m)**2 * np.cos(self.ring.k1 * self.m * z +
+                                                   self.psi - PHI)
 
     def update_feedback(self):
         """Force feedback update from current CavityResonator parameters."""
         for FB in self.feedback:
             if isinstance(FB, (ProportionalIntegralLoop, DirectFeedback)):
                 FB.init_Ig2Vg_matrix()
-            
+
     def sample_voltage(self, n_points=1e4, index=0):
         """
         Sample the voltage seen by a zero charge particle during an RF period.
@@ -951,28 +1040,32 @@ class CavityResonator():
         n_points = int(n_points)
         index = 0
         voltage_rec = np.zeros(n_points)
-        pos = np.linspace(-self.ring.T1/2, self.ring.T1/2, n_points)
-        DeltaT = self.ring.T1/(n_points-1)
-        
+        pos = np.linspace(-self.ring.T1 / 2, self.ring.T1 / 2, n_points)
+        DeltaT = self.ring.T1 / (n_points-1)
+
         # From t=0 of first non empty bunch to -T1/2
-        self.phasor_decay(-self.ring.T1/2 + index*self.ring.T1, 
+        self.phasor_decay(-self.ring.T1 / 2 + index * self.ring.T1,
                           ref_frame="beam")
 
         # Goes from (-T1/2) to (T1/2 + DeltaT) in n_points steps
         for i in range(n_points):
-            phase = self.m * self.ring.omega1 * (pos[i] + self.ring.T1* (index + self.ring.h * self.nturn))
-            Vgene = np.real(self.generator_phasor_record[index]*np.exp(1j*phase))
+            phase = self.m * self.ring.omega1 * (
+                pos[i] + self.ring.T1 * (index + self.ring.h * self.nturn))
+            Vgene = np.real(self.generator_phasor_record[index] *
+                            np.exp(1j * phase))
             Vbeam = np.real(self.beam_phasor)
             Vtot = Vgene + Vbeam
             voltage_rec[i] = Vtot
             self.phasor_decay(DeltaT, ref_frame="beam")
 
         # Get back to t=0
-        self.phasor_decay(-DeltaT*n_points + self.ring.T1/2 - index*self.ring.T1, 
+        self.phasor_decay(-DeltaT * n_points + self.ring.T1 / 2 -
+                          index * self.ring.T1,
                           ref_frame="beam")
-        
+
         return pos, voltage_rec
 
+
 class ProportionalLoop():
     """
     Proportional feedback loop to control a CavityResonator amplitude and phase.
@@ -1004,9 +1097,9 @@ class ProportionalLoop():
         self.delay = int(delay)
         if delay < 1:
             raise ValueError("delay must be >= 1.")
-        self.volt_delay = np.ones(self.delay)*self.cav_res.Vc
-        self.phase_delay = np.ones(self.delay)*self.cav_res.theta
-    
+        self.volt_delay = np.ones(self.delay) * self.cav_res.Vc
+        self.phase_delay = np.ones(self.delay) * self.cav_res.theta
+
     def track(self):
         """
         Tracking method for the amplitude and phase loop.
@@ -1018,14 +1111,16 @@ class ProportionalLoop():
         """
         diff_A = self.volt_delay[-1] - self.cav_res.Vc
         diff_P = self.phase_delay[-1] - self.cav_res.theta
-        self.cav_res.Vg -= self.gain_A*diff_A
-        self.cav_res.theta_g -= self.gain_P*diff_P
-        self.cav_res.generator_phasor_record = np.ones(self.ring.h)*self.cav_res.generator_phasor
+        self.cav_res.Vg -= self.gain_A * diff_A
+        self.cav_res.theta_g -= self.gain_P * diff_P
+        self.cav_res.generator_phasor_record = np.ones(
+            self.ring.h) * self.cav_res.generator_phasor
         self.volt_delay = np.roll(self.volt_delay, 1)
         self.phase_delay = np.roll(self.phase_delay, 1)
         self.volt_delay[0] = self.cav_res.cavity_voltage
         self.phase_delay[0] = self.cav_res.cavity_phase
 
+
 class TunerLoop():
     """
     Cavity tuner loop used to control a CavityResonator tuning (psi or detune)
@@ -1055,20 +1150,20 @@ class TunerLoop():
         Tuning offset in [rad].
         
     """
-    def __init__(self, ring, cav_res, gain=0.01, avering_period=0, 
-                 offset=0):
+    def __init__(self, ring, cav_res, gain=0.01, avering_period=0, offset=0):
         self.ring = ring
         self.cav_res = cav_res
         if avering_period == 0:
-            fs = self.ring.synchrotron_tune(self.cav_res.Vc)*self.ring.f1/self.ring.h
-            avering_period = 2/fs/self.ring.T0
-    
+            fs = self.ring.synchrotron_tune(
+                self.cav_res.Vc) * self.ring.f1 / self.ring.h
+            avering_period = 2 / fs / self.ring.T0
+
         self.Pgain = gain
         self.offset = offset
         self.avering_period = int(avering_period)
         self.diff = 0
         self.count = 0
-    
+
     def track(self):
         """
         Tracking method for the tuner loop.
@@ -1079,7 +1174,7 @@ class TunerLoop():
     
         """
         if self.count == self.avering_period:
-            diff = self.diff/self.avering_period - self.offset
+            diff = self.diff / self.avering_period - self.offset
             self.cav_res.psi -= diff * self.Pgain
             self.count = 0
             self.diff = 0
@@ -1087,6 +1182,7 @@ class TunerLoop():
             self.diff += self.cav_res.cavity_phase - self.cav_res.theta_g + self.cav_res.psi
             self.count += 1
 
+
 class ProportionalIntegralLoop():
     """
     Proportional Integral (PI) loop to control a CavityResonator amplitude and 
@@ -1198,20 +1294,22 @@ class ProportionalIntegralLoop():
             self.every = int(every)
         else:
             self.every = 1
-        record_size = int(np.ceil(self.delay/self.every))
+        record_size = int(np.ceil(self.delay / self.every))
         if record_size < 1:
             raise ValueError("Bad parameter set : delay or every")
         self.sample_num = int(sample_num)
 
         # init lists for FB process
-        self.ig_phasor = np.ones(self.ring.h, dtype=complex)*self.Vg2Ig(self.cav_res.generator_phasor)
+        self.ig_phasor = np.ones(self.ring.h, dtype=complex) * self.Vg2Ig(
+            self.cav_res.generator_phasor)
         self.ig_phasor_record = self.ig_phasor
-        self.vc_previous = np.ones(self.sample_num)*self.cav_res.cavity_phasor
+        self.vc_previous = np.ones(
+            self.sample_num) * self.cav_res.cavity_phasor
         self.diff_record = np.zeros(record_size, dtype=complex)
-        self.I_record = 0+0j
+        self.I_record = 0 + 0j
 
         self.sample_list = range(0, self.ring.h, self.every)
-        
+
         self.init_FFconst()
 
         # Pre caclulation for Ig2Vg
@@ -1226,30 +1324,34 @@ class ProportionalIntegralLoop():
         None.
 
         """
-        vc_list = np.concatenate([self.vc_previous, self.cav_res.cavity_phasor_record]) #This line is slowing down the process.
+        vc_list = np.concatenate([
+            self.vc_previous, self.cav_res.cavity_phasor_record
+        ])  #This line is slowing down the process.
         self.ig_phasor.fill(self.ig_phasor[-1])
-        
-        for index in self.sample_list: 
+
+        for index in self.sample_list:
             # 2) updating Ig using last item of the list
             diff = self.diff_record[-1] - self.FFconst
-            self.I_record += diff/self.ring.f1
-            fb_value = self.gain[0]*diff + self.gain[1]*self.I_record
+            self.I_record += diff / self.ring.f1
+            fb_value = self.gain[0] * diff + self.gain[1] * self.I_record
             self.ig_phasor[index:] = self.Vg2Ig(fb_value) + self.FFconst
             # Shift the record
             self.diff_record = np.roll(self.diff_record, 1)
             # 1) recording diff as a first item of the list
-            mean_vc = np.mean(vc_list[index:self.sample_num+index])*np.exp(-1j*self.cav_res.theta)
+            mean_vc = np.mean(vc_list[index:self.sample_num + index]) * np.exp(
+                -1j * self.cav_res.theta)
             self.diff_record[0] = self.cav_res.Vc - mean_vc
         # update sample_list for next turn
-        self.sample_list = range(index+self.every-self.ring.h,self.ring.h,self.every)
+        self.sample_list = range(index + self.every - self.ring.h, self.ring.h,
+                                 self.every)
         # update vc_previous for next turn
-        self.vc_previous = self.cav_res.cavity_phasor_record[- self.sample_num:]
-        
+        self.vc_previous = self.cav_res.cavity_phasor_record[-self.sample_num:]
+
         self.ig_phasor_record = self.ig_phasor
-        
+
         if apply_changes:
             self.Ig2Vg()
-            
+
     def init_Ig2Vg_matrix(self):
         """
         Initialize matrix for Ig2Vg_matrix.
@@ -1258,28 +1360,33 @@ class ProportionalIntegralLoop():
         parameter change.
         """
         k = np.arange(0, self.ring.h)
-        self.Ig2Vg_vec = np.exp(-1/self.cav_res.filling_time*(1 - 1j*np.tan(self.cav_res.psi))*self.ring.T1*(k+1))
-        tempV = np.exp(-1/self.cav_res.filling_time*self.ring.T1*k*(1 - 1j*np.tan(self.cav_res.psi)))
+        self.Ig2Vg_vec = np.exp(-1 / self.cav_res.filling_time *
+                                (1 - 1j * np.tan(self.cav_res.psi)) *
+                                self.ring.T1 * (k+1))
+        tempV = np.exp(-1 / self.cav_res.filling_time * self.ring.T1 * k *
+                       (1 - 1j * np.tan(self.cav_res.psi)))
         for idx in np.arange(self.ring.h):
-            self.Ig2Vg_mat[idx:,idx] = tempV[:self.ring.h-idx]
-            
+            self.Ig2Vg_mat[idx:, idx] = tempV[:self.ring.h - idx]
+
     def init_FFconst(self):
         """Initialize feedforward constant."""
         if self.FF:
             self.FFconst = np.mean(self.ig_phasor)
         else:
             self.FFconst = 0
-            
+
     def Ig2Vg_matrix(self):
         """
         Return Vg from Ig using matrix formalism.
         Warning: self.init_Ig2Vg should be called after each CavityResonator 
         parameter change.
         """
-        generator_phasor_record = (self.Ig2Vg_vec*self.cav_res.generator_phasor_record[-1] +
-                    self.Ig2Vg_mat.dot(self.ig_phasor_record)*self.cav_res.loss_factor*self.ring.T1)
+        generator_phasor_record = (
+            self.Ig2Vg_vec * self.cav_res.generator_phasor_record[-1] +
+            self.Ig2Vg_mat.dot(self.ig_phasor_record) *
+            self.cav_res.loss_factor * self.ring.T1)
         return generator_phasor_record
-            
+
     def Ig2Vg(self):
         """
         Go from Ig to Vg.
@@ -1289,15 +1396,17 @@ class ProportionalIntegralLoop():
         """
         self.cav_res.generator_phasor_record = self.Ig2Vg_matrix()
         self.cav_res.Vg = np.mean(np.abs(self.cav_res.generator_phasor_record))
-        self.cav_res.theta_g = np.mean(np.angle(self.cav_res.generator_phasor_record))
-        
+        self.cav_res.theta_g = np.mean(
+            np.angle(self.cav_res.generator_phasor_record))
+
     def Vg2Ig(self, Vg):
         """
         Return Ig from Vg (assuming constant Vg).
 
         Eq.25 of ref [2] assuming the dVg/dt = 0.
         """
-        return Vg * ( 1 - 1j*np.tan(self.cav_res.psi) ) / self.cav_res.RL
+        return Vg * (1 - 1j * np.tan(self.cav_res.psi)) / self.cav_res.RL
+
 
 class DirectFeedback(ProportionalIntegralLoop):
     """
@@ -1371,10 +1480,15 @@ class DirectFeedback(ProportionalIntegralLoop):
     ring. In Proc. IPAC'23. doi:10.18429/JACoW-IPAC2023-WEPL161
 
     """
-    def __init__(self, DFB_gain, DFB_phase_shift, DFB_sample_num=None,
-                 DFB_every=None, DFB_delay=None, **kwargs):
+    def __init__(self,
+                 DFB_gain,
+                 DFB_phase_shift,
+                 DFB_sample_num=None,
+                 DFB_every=None,
+                 DFB_delay=None,
+                 **kwargs):
         super(DirectFeedback, self).__init__(**kwargs)
-        
+
         if DFB_delay is not None:
             self.DFB_delay = int(DFB_delay)
         else:
@@ -1384,37 +1498,38 @@ class DirectFeedback(ProportionalIntegralLoop):
             self.DFB_sample_num = int(DFB_sample_num)
         else:
             self.DFB_sample_num = self.sample_num
-            
+
         if DFB_every is not None:
             self.DFB_every = int(DFB_every)
         else:
             self.DFB_every = self.every
 
-        record_size = int(np.ceil(self.DFB_delay/self.DFB_every))
+        record_size = int(np.ceil(self.DFB_delay / self.DFB_every))
         if record_size < 1:
             raise ValueError("Bad parameter set : DFB_delay or DFB_every")
 
         self.DFB_parameter_set(DFB_gain, DFB_phase_shift)
         if np.sum(np.abs(self.cav_res.beam_phasor)) == 0:
-            cavity_phasor = self.cav_res.Vc*np.exp(1j*self.cav_res.theta)
+            cavity_phasor = self.cav_res.Vc * np.exp(1j * self.cav_res.theta)
         else:
             cavity_phasor = np.mean(self.cav_res.cavity_phasor_record)
-        self.DFB_VcRecord = np.ones(record_size, dtype=complex)*cavity_phasor
-        self.DFB_vc_previous = np.ones(self.DFB_sample_num, dtype=complex)*cavity_phasor 
+        self.DFB_VcRecord = np.ones(record_size, dtype=complex) * cavity_phasor
+        self.DFB_vc_previous = np.ones(self.DFB_sample_num,
+                                       dtype=complex) * cavity_phasor
 
         self.DFB_sample_list = range(0, self.ring.h, self.DFB_every)
-        
+
     @property
     def DFB_phase_shift(self):
         """Return DFB phase shift."""
         return self._DFB_phase_shift
-    
+
     @DFB_phase_shift.setter
     def DFB_phase_shift(self, value):
         """Set DFB_phase_shift and phase_shift"""
         self._DFB_phase_shift = value
         self._phase_shift = self.cav_res.psi - value
-       
+
     @property
     def phase_shift(self):
         """
@@ -1422,7 +1537,7 @@ class DirectFeedback(ProportionalIntegralLoop):
         Defined as self.cav_res.psi - self.DFB_phase_shift.
         """
         return self._phase_shift
-    
+
     @property
     def DFB_psi(self):
         """
@@ -1430,9 +1545,9 @@ class DirectFeedback(ProportionalIntegralLoop):
     
         Fig.4 of ref [1].
         """
-        return (np.angle(np.mean(self.cav_res.cavity_phasor_record)) - 
+        return (np.angle(np.mean(self.cav_res.cavity_phasor_record)) -
                 np.angle(np.mean(self.ig_phasor_record)))
-    
+
     @property
     def DFB_alpha(self):
         """
@@ -1440,9 +1555,10 @@ class DirectFeedback(ProportionalIntegralLoop):
         
         Near Eq. (13) of [1].
         """
-        fac = np.abs(np.mean(self.DFB_ig_phasor)/np.mean(self.ig_phasor_record))
-        return 20*np.log10(fac)
-    
+        fac = np.abs(
+            np.mean(self.DFB_ig_phasor) / np.mean(self.ig_phasor_record))
+        return 20 * np.log10(fac)
+
     @property
     def DFB_gamma(self):
         """
@@ -1450,9 +1566,10 @@ class DirectFeedback(ProportionalIntegralLoop):
         
         Near Eq. (13) of [1].
         """
-        fac = np.abs(np.mean(self.DFB_ig_phasor)/np.mean(self.ig_phasor_record))
-        return fac/(1-fac)
-    
+        fac = np.abs(
+            np.mean(self.DFB_ig_phasor) / np.mean(self.ig_phasor_record))
+        return fac / (1-fac)
+
     @property
     def DFB_Rs(self):
         """
@@ -1460,8 +1577,8 @@ class DirectFeedback(ProportionalIntegralLoop):
         
         Eq. (15) of [1].
         """
-        return self.cav_res.Rs/(1+self.DFB_gamma*np.cos(self.DFB_psi))
-        
+        return self.cav_res.Rs / (1 + self.DFB_gamma * np.cos(self.DFB_psi))
+
     def DFB_parameter_set(self, DFB_gain, DFB_phase_shift):
         """
         Set DFB gain and phase shift parameters.
@@ -1478,10 +1595,11 @@ class DirectFeedback(ProportionalIntegralLoop):
         self.DFB_phase_shift = DFB_phase_shift
 
         if np.sum(np.abs(self.cav_res.beam_phasor)) == 0:
-            vc = np.ones(self.ring.h)*self.cav_res.Vc*np.exp(1j*self.cav_res.theta)
+            vc = np.ones(self.ring.h) * self.cav_res.Vc * np.exp(
+                1j * self.cav_res.theta)
         else:
             vc = self.cav_res.cavity_phasor_record
-        vg_drf = self.DFB_gain*vc*np.exp(1j*self.phase_shift)
+        vg_drf = self.DFB_gain * vc * np.exp(1j * self.phase_shift)
         self.DFB_ig_phasor = self.Vg2Ig(vg_drf)
 
         self.ig_phasor = self.ig_phasor_record - self.DFB_ig_phasor
@@ -1497,44 +1615,50 @@ class DirectFeedback(ProportionalIntegralLoop):
 
         """
         super(DirectFeedback, self).track(False)
-        
-        vc_list = np.concatenate([self.DFB_vc_previous, self.cav_res.cavity_phasor_record])
+
+        vc_list = np.concatenate(
+            [self.DFB_vc_previous, self.cav_res.cavity_phasor_record])
         self.DFB_ig_phasor = np.roll(self.DFB_ig_phasor, 1)
         for index in self.DFB_sample_list:
             # 2) updating Ig using last item of the list
-            vg_drf = self.DFB_gain*self.DFB_VcRecord[-1]*np.exp(1j*self.phase_shift)
+            vg_drf = self.DFB_gain * self.DFB_VcRecord[-1] * np.exp(
+                1j * self.phase_shift)
             self.DFB_ig_phasor[index:] = self.Vg2Ig(vg_drf)
             # Shift the record
             self.DFB_VcRecord = np.roll(self.DFB_VcRecord, 1)
             # 1) recording Vc
-            mean_vc = np.mean(vc_list[index:self.DFB_sample_num+index])
+            mean_vc = np.mean(vc_list[index:self.DFB_sample_num + index])
             self.DFB_VcRecord[0] = mean_vc
         # update sample_list for next turn
-        self.DFB_sample_list = range(index+self.DFB_every-self.ring.h, self.ring.h, self.DFB_every)
+        self.DFB_sample_list = range(index + self.DFB_every - self.ring.h,
+                                     self.ring.h, self.DFB_every)
         # update vc_previous for next turn
-        self.DFB_vc_previous = self.cav_res.cavity_phasor_record[- self.DFB_sample_num:]
+        self.DFB_vc_previous = self.cav_res.cavity_phasor_record[
+            -self.DFB_sample_num:]
 
         self.ig_phasor_record = self.ig_phasor + self.DFB_ig_phasor
-        
+
         self.Ig2Vg()
-    
+
     def DFB_Vg(self, vc=-1):
         """Return the generator voltage main and DFB components in [V]."""
         if vc == -1:
             vc = np.mean(self.cav_res.cavity_phasor_record)
-        vg_drf=self.DFB_gain*vc*np.exp(1j*self.phase_shift)
-        vg_main=np.mean(self.cav_res.generator_phasor_record)-vg_drf
+        vg_drf = self.DFB_gain * vc * np.exp(1j * self.phase_shift)
+        vg_main = np.mean(self.cav_res.generator_phasor_record) - vg_drf
         return vg_main, vg_drf
 
     def DFB_fs(self, vg_main=-1, vg_drf=-1):
         """Return the modified synchrotron frequency in [Hz]."""
         vc = np.mean(self.cav_res.cavity_phasor_record)
-        if vg_drf ==-1:
-            vg_drf=self.DFB_gain*vc*np.exp(1j*self.phase_shift)
+        if vg_drf == -1:
+            vg_drf = self.DFB_gain * vc * np.exp(1j * self.phase_shift)
         if vg_main == -1:
-            vg_main=np.mean(self.cav_res.generator_phasor_record)-vg_drf
-        vg_sum = np.abs(vg_main)*np.sin(np.angle(vg_main))+np.abs(vg_drf)*np.sin(np.angle(vg_drf))
+            vg_main = np.mean(self.cav_res.generator_phasor_record) - vg_drf
+        vg_sum = np.abs(vg_main) * np.sin(
+            np.angle(vg_main)) + np.abs(vg_drf) * np.sin(np.angle(vg_drf))
         omega_s = 0
         if (vg_sum) > 0.0:
-            omega_s=np.sqrt(self.ring.ac*self.ring.omega1*(vg_sum)/self.ring.E0/self.ring.T0)
-        return omega_s/2/np.pi
+            omega_s = np.sqrt(self.ring.ac * self.ring.omega1 * (vg_sum) /
+                              self.ring.E0 / self.ring.T0)
+        return omega_s / 2 / np.pi
diff --git a/mbtrack2/tracking/synchrotron.py b/mbtrack2/tracking/synchrotron.py
index d5c5d0eaa835a92f6bc5e9bba8ddb9a86e56e4e5..24fe189a8833c0c7ebde7ed24f77989c7ecfef88 100644
--- a/mbtrack2/tracking/synchrotron.py
+++ b/mbtrack2/tracking/synchrotron.py
@@ -3,10 +3,11 @@
 Module where the Synchrotron class is defined.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from scipy.constants import c, e
-        
+
+
 class Synchrotron:
     """
     Synchrotron class to store main properties.
@@ -121,7 +122,7 @@ class Synchrotron:
         self.long_alpha = 0
         self.long_beta = 0
         self.long_gamma = 0
-        
+
         if self.optics.use_local_values == False:
             self.L = kwargs.get('L', self.optics.lattice.circumference)
             self.E0 = kwargs.get('E0', self.optics.lattice.energy)
@@ -130,109 +131,113 @@ class Synchrotron:
             self.chro = kwargs.get('chro', self.optics.chro)
             self.U0 = kwargs.get('U0', self.optics.lattice.energy_loss)
         else:
-            self.L = kwargs.get('L') # Ring circumference [m]
-            self.E0 = kwargs.get('E0') # Nominal (total) energy of the ring [eV]
-            self.ac = kwargs.get('ac') # Momentum compaction factor
-            self.tune = kwargs.get('tune') # X/Y/S tunes
-            self.chro = kwargs.get('chro') # X/Y (non-normalized) chromaticities
-            self.U0 = kwargs.get('U0') # Energy loss per turn [eV]
-            
-        self.tau = kwargs.get('tau') # X/Y/S damping times [s]
-        self.sigma_delta = kwargs.get('sigma_delta') # Equilibrium energy spread
-        self.sigma_0 = kwargs.get('sigma_0') # Natural bunch length [s]
-        self.emit = kwargs.get('emit') # X/Y emittances in [m.rad]
-        self.adts = kwargs.get('adts') # Amplitude-Dependent Tune Shift (ADTS)
-        self.mcf_order = kwargs.get('mcf_order', self.ac) # Higher-orders momentum compaction factor
-        
+            self.L = kwargs.get('L')  # Ring circumference [m]
+            self.E0 = kwargs.get(
+                'E0')  # Nominal (total) energy of the ring [eV]
+            self.ac = kwargs.get('ac')  # Momentum compaction factor
+            self.tune = kwargs.get('tune')  # X/Y/S tunes
+            self.chro = kwargs.get(
+                'chro')  # X/Y (non-normalized) chromaticities
+            self.U0 = kwargs.get('U0')  # Energy loss per turn [eV]
+
+        self.tau = kwargs.get('tau')  # X/Y/S damping times [s]
+        self.sigma_delta = kwargs.get(
+            'sigma_delta')  # Equilibrium energy spread
+        self.sigma_0 = kwargs.get('sigma_0')  # Natural bunch length [s]
+        self.emit = kwargs.get('emit')  # X/Y emittances in [m.rad]
+        self.adts = kwargs.get('adts')  # Amplitude-Dependent Tune Shift (ADTS)
+        self.mcf_order = kwargs.get(
+            'mcf_order', self.ac)  # Higher-orders momentum compaction factor
+
     @property
     def h(self):
         """Harmonic number"""
         return self._h
-    
+
     @h.setter
     def h(self, value):
         self._h = value
         self.L = self.L  # call setter
-        
+
     @property
     def L(self):
         """Ring circumference [m]"""
         return self._L
-    
+
     @L.setter
-    def L(self,value):
+    def L(self, value):
         self._L = value
-        self._T0 = self.L/c
-        self._T1 = self.T0/self.h
-        self._f0 = 1/self.T0
-        self._omega0 = 2*np.pi*self.f0
-        self._f1 = self.h*self.f0
-        self._omega1 = 2*np.pi*self.f1
-        self._k1 = self.omega1/c
-        
+        self._T0 = self.L / c
+        self._T1 = self.T0 / self.h
+        self._f0 = 1 / self.T0
+        self._omega0 = 2 * np.pi * self.f0
+        self._f1 = self.h * self.f0
+        self._omega1 = 2 * np.pi * self.f1
+        self._k1 = self.omega1 / c
+
     @property
     def T0(self):
         """Revolution time [s]"""
         return self._T0
-    
+
     @T0.setter
     def T0(self, value):
-        self.L = c*value
-        
+        self.L = c * value
+
     @property
     def T1(self):
         """"Fundamental RF period [s]"""
         return self._T1
-    
+
     @T1.setter
     def T1(self, value):
-        self.L = c*value*self.h
-        
+        self.L = c * value * self.h
+
     @property
     def f0(self):
         """Revolution frequency [Hz]"""
         return self._f0
-    
+
     @f0.setter
-    def f0(self,value):
-        self.L = c/value
-        
+    def f0(self, value):
+        self.L = c / value
+
     @property
     def omega0(self):
         """Angular revolution frequency [Hz rad]"""
         return self._omega0
-    
+
     @omega0.setter
-    def omega0(self,value):
-        self.L = 2*np.pi*c/value
-        
+    def omega0(self, value):
+        self.L = 2 * np.pi * c / value
+
     @property
     def f1(self):
         """Fundamental RF frequency [Hz]"""
         return self._f1
-    
+
     @f1.setter
-    def f1(self,value):
-        self.L = self.h*c/value
-        
+    def f1(self, value):
+        self.L = self.h * c / value
+
     @property
     def omega1(self):
         """Fundamental RF angular frequency[Hz rad]"""
         return self._omega1
-    
+
     @omega1.setter
-    def omega1(self,value):
-        self.L = 2*np.pi*self.h*c/value
-        
+    def omega1(self, value):
+        self.L = 2 * np.pi * self.h * c / value
+
     @property
     def k1(self):
         """Fundamental RF wave number [m**-1]"""
         return self._k1
-    
+
     @k1.setter
-    def k1(self,value):
-        self.L = 2*np.pi*self.h/value
-    
+    def k1(self, value):
+        self.L = 2 * np.pi * self.h / value
+
     @property
     def gamma(self):
         """Relativistic gamma"""
@@ -242,7 +247,7 @@ class Synchrotron:
     def gamma(self, value):
         self._gamma = value
         self._beta = np.sqrt(1 - self.gamma**-2)
-        self._E0 = self.gamma*self.particle.mass*c**2/e
+        self._E0 = self.gamma * self.particle.mass * c**2 / e
 
     @property
     def beta(self):
@@ -251,27 +256,27 @@ class Synchrotron:
 
     @beta.setter
     def beta(self, value):
-        self.gamma = 1/np.sqrt(1-value**2)
-        
+        self.gamma = 1 / np.sqrt(1 - value**2)
+
     @property
     def E0(self):
         """Nominal (total) energy of the ring [eV]"""
         return self._E0
-    
+
     @E0.setter
     def E0(self, value):
-        self.gamma = value/(self.particle.mass*c**2/e)
+        self.gamma = value / (self.particle.mass * c**2 / e)
 
     @property
     def mcf_order(self):
         """Higher-orders momentum compaction factor"""
         return self._mcf_order
-    
+
     @mcf_order.setter
     def mcf_order(self, value):
         self._mcf_order = value
         self.mcf = np.poly1d(self.mcf_order)
-        
+
     def eta(self, delta=0):
         """
         Momentum compaction taking into account higher orders if provided in
@@ -288,8 +293,8 @@ class Synchrotron:
             Momentum compaction.
 
         """
-        return self.mcf(delta) - 1/(self.gamma**2)
-    
+        return self.mcf(delta) - 1 / (self.gamma**2)
+
     def sigma(self, position=None):
         """
         Return the RMS beam size at equilibrium in [m].
@@ -308,31 +313,39 @@ class Synchrotron:
 
         """
         if position is None:
-            sigma = np.zeros((4,))
-            sigma[0] = (self.emit[0]*self.optics.local_beta[0] +
-                        self.optics.local_dispersion[0]**2*self.sigma_delta**2)**0.5
-            sigma[1] = (self.emit[0]*self.optics.local_gamma[0] +
-                        self.optics.local_dispersion[1]**2*self.sigma_delta**2)**0.5
-            sigma[2] = (self.emit[1]*self.optics.local_beta[1] +
-                        self.optics.local_dispersion[2]**2*self.sigma_delta**2)**0.5
-            sigma[3] = (self.emit[1]*self.optics.local_gamma[1] +
-                        self.optics.local_dispersion[3]**2*self.sigma_delta**2)**0.5
+            sigma = np.zeros((4, ))
+            sigma[0] = (
+                self.emit[0] * self.optics.local_beta[0] +
+                self.optics.local_dispersion[0]**2 * self.sigma_delta**2)**0.5
+            sigma[1] = (
+                self.emit[0] * self.optics.local_gamma[0] +
+                self.optics.local_dispersion[1]**2 * self.sigma_delta**2)**0.5
+            sigma[2] = (
+                self.emit[1] * self.optics.local_beta[1] +
+                self.optics.local_dispersion[2]**2 * self.sigma_delta**2)**0.5
+            sigma[3] = (
+                self.emit[1] * self.optics.local_gamma[1] +
+                self.optics.local_dispersion[3]**2 * self.sigma_delta**2)**0.5
         else:
             if isinstance(position, (float, int)):
                 n = 1
             else:
                 n = len(position)
             sigma = np.zeros((4, n))
-            sigma[0,:] = (self.emit[0]*self.optics.beta(position)[0] +
-                        self.optics.dispersion(position)[0]**2*self.sigma_delta**2)**0.5
-            sigma[1,:] = (self.emit[0]*self.optics.gamma(position)[0] +
-                        self.optics.dispersion(position)[1]**2*self.sigma_delta**2)**0.5
-            sigma[2,:] = (self.emit[1]*self.optics.beta(position)[1] +
-                        self.optics.dispersion(position)[2]**2*self.sigma_delta**2)**0.5
-            sigma[3,:] = (self.emit[1]*self.optics.gamma(position)[1] +
-                        self.optics.dispersion(position)[3]**2*self.sigma_delta**2)**0.5
+            sigma[0, :] = (self.emit[0] * self.optics.beta(position)[0] +
+                           self.optics.dispersion(position)[0]**2 *
+                           self.sigma_delta**2)**0.5
+            sigma[1, :] = (self.emit[0] * self.optics.gamma(position)[0] +
+                           self.optics.dispersion(position)[1]**2 *
+                           self.sigma_delta**2)**0.5
+            sigma[2, :] = (self.emit[1] * self.optics.beta(position)[1] +
+                           self.optics.dispersion(position)[2]**2 *
+                           self.sigma_delta**2)**0.5
+            sigma[3, :] = (self.emit[1] * self.optics.gamma(position)[1] +
+                           self.optics.dispersion(position)[3]**2 *
+                           self.sigma_delta**2)**0.5
         return sigma
-    
+
     def synchrotron_tune(self, V):
         """
         Compute the (unperturbed) synchrotron tune from main RF voltage.
@@ -348,11 +361,11 @@ class Synchrotron:
             Synchrotron tune.
             
         """
-        Vsum =  V * np.sin(np.arccos(self.U0/V))
+        Vsum = V * np.sin(np.arccos(self.U0 / V))
         phi = np.arccos(1 - self.eta(0) * np.pi * self.h / self.E0 * Vsum)
         tuneS = phi / (2 * np.pi)
         return tuneS
-    
+
     def get_adts(self):
         """
         Compute and add Amplitude-Dependent Tune Shifts (ADTS) sextupolar 
@@ -360,16 +373,16 @@ class Synchrotron:
         """
         import at
         if self.optics.use_local_values:
-            raise ValueError("ADTS needs to be provided manualy as no AT" + 
+            raise ValueError("ADTS needs to be provided manualy as no AT" +
                              " lattice file is loaded.")
-            
+
         det = at.physics.nonlinear.gen_detuning_elem(self.optics.lattice)
-        coef_xx = np.array([det.A1/2, 0])
-        coef_yx = np.array([det.A2/2, 0])
-        coef_xy = np.array([det.A2/2, 0])
-        coef_yy = np.array([det.A3/2, 0])
+        coef_xx = np.array([det.A1 / 2, 0])
+        coef_yx = np.array([det.A2 / 2, 0])
+        coef_xy = np.array([det.A2 / 2, 0])
+        coef_yy = np.array([det.A3 / 2, 0])
         self.adts = [coef_xx, coef_yx, coef_xy, coef_yy]
-        
+
     def get_mcf_order(self, add=True, show_fit=False):
         """
         Compute momentum compaction factor up to 3rd order from AT lattice.
@@ -391,7 +404,7 @@ class Synchrotron:
         """
         import at
         if self.optics.use_local_values:
-            raise ValueError("ADTS needs to be provided manualy as no AT" + 
+            raise ValueError("ADTS needs to be provided manualy as no AT" +
                              " lattice file is loaded.")
         deltamin = -1e-4
         deltamax = 1e-4
@@ -401,21 +414,21 @@ class Synchrotron:
         alpha = np.zeros_like(delta)
 
         for i in range(len(delta)):
-            alpha[i] = at.physics.revolution.get_mcf(self.optics.lattice, 
+            alpha[i] = at.physics.revolution.get_mcf(self.optics.lattice,
                                                      delta[i])
         pvalue = np.polyfit(delta, alpha, 2)
 
         if show_fit:
             pvalue = np.polyfit(delta, alpha, 2)
             palpha = np.polyval(pvalue, delta4eval)
-    
-            plt.plot(delta*100, alpha,'k.')
+
+            plt.plot(delta * 100, alpha, 'k.')
             plt.grid()
-            plt.plot(delta4eval*100,palpha,'r')
+            plt.plot(delta4eval * 100, palpha, 'r')
             plt.xlabel('Energy (%)')
             plt.ylabel('Momemtum compaction factor')
             plt.legend(['Data', 'polyfit'])
-            
+
         if add:
             self.mcf_order = pvalue
         else:
@@ -463,25 +476,26 @@ class Synchrotron:
         long_gamma : float
             Longitudinal gamma Twiss parameter at the tracking location in [s-1].
         """
-        
+
         if isinstance(V, float) or isinstance(V, int):
             V = [V]
             if phase is None:
-                phase = [np.arccos(self.U0/V[0])]
+                phase = [np.arccos(self.U0 / V[0])]
             elif isinstance(phase, float) or isinstance(phase, int):
                 phase = [phase]
             if harmonics is None:
                 harmonics = [1]
-                
+
         if not (len(V) == len(phase) == len(harmonics)):
             raise ValueError("You must provide array of the same length for"
                              " V, phase and harmonics")
-        
+
         Vsum = 0
         for i in range(len(V)):
             Vsum += harmonics[i] * V[i] * np.sin(phase[i])
         phi = np.arccos(1 - self.eta(0) * np.pi * self.h / self.E0 * Vsum)
-        long_alpha = - self.eta(0) * np.pi * self.h / (self.E0 * np.sin(phi)) * Vsum
+        long_alpha = -self.eta(0) * np.pi * self.h / (self.E0 *
+                                                      np.sin(phi)) * Vsum
         long_beta = self.eta(0) * self.T0 / np.sin(phi)
         long_gamma = self.omega1 * Vsum / (self.E0 * np.sin(phi))
         tuneS = phi / (2 * np.pi)
@@ -491,4 +505,4 @@ class Synchrotron:
             self.long_beta = long_beta
             self.long_gamma = long_gamma
         else:
-            return tuneS, long_alpha, long_beta, long_gamma
\ No newline at end of file
+            return tuneS, long_alpha, long_beta, long_gamma
diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py
index 0d30f44f57964e780279f38a255c8a8cb7e3c500..5881a2d706b5c041910b3913706e7004d582c847 100644
--- a/mbtrack2/tracking/wakepotential.py
+++ b/mbtrack2/tracking/wakepotential.py
@@ -4,15 +4,17 @@ This module defines the WakePotential and LongRangeResistiveWall classes which
 deal with the single bunch and multi-bunch wakes.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
 from scipy import signal
+from scipy.constants import c, mu_0, pi
 from scipy.interpolate import interp1d
-from scipy.constants import mu_0, c, pi
+
 from mbtrack2.tracking.element import Element
 from mbtrack2.utilities.spectrum import gaussian_bunch
-   
+
+
 class WakePotential(Element):
     """
     Compute a wake potential from uniformly sampled wake functions by 
@@ -69,7 +71,6 @@ class WakePotential(Element):
         Reduce wake function samping by an integer factor.
         
     """
-    
     def __init__(self, ring, wakefield, n_bin=80):
         self.wakefield = wakefield
         self.types = self.wakefield.wake_components
@@ -77,10 +78,10 @@ class WakePotential(Element):
         self.ring = ring
         self.n_bin = n_bin
         self.check_sampling()
-        
+
         # Suppress numpy warning for floating-point operations.
         np.seterr(invalid='ignore')
-            
+
     def charge_density(self, bunch):
         """
         Compute bunch charge density profile in [1/s].
@@ -90,7 +91,7 @@ class WakePotential(Element):
         bunch : Bunch object
 
         """
-        
+
         # Get binning data
         a, b, c, d = bunch.binning(n_bin=self.n_bin)
         self.bins = a
@@ -98,37 +99,36 @@ class WakePotential(Element):
         self.profile = c
         self.center = d
         self.bin_size = self.bins[1] - self.bins[0]
-        
+
         # Compute charge density
-        self.rho = bunch.charge_per_mp*self.profile/(self.bin_size*bunch.charge)
+        self.rho = bunch.charge_per_mp * self.profile / (self.bin_size *
+                                                         bunch.charge)
         self.rho = np.array(self.rho)
-        
+
         # Compute time array
         self.tau = np.array(self.center)
         self.dtau = self.tau[1] - self.tau[0]
-        
+
         # Add N values before and after rho and tau
         if self.n_bin % 2 == 0:
-            N = int(self.n_bin/2)
-            self.tau = np.arange(self.tau[0] - self.dtau*N, 
-                                 self.tau[-1] + self.dtau*N,
-                                 self.dtau)
+            N = int(self.n_bin / 2)
+            self.tau = np.arange(self.tau[0] - self.dtau * N,
+                                 self.tau[-1] + self.dtau * N, self.dtau)
             self.rho = np.append(self.rho, np.zeros(N))
             self.rho = np.insert(self.rho, 0, np.zeros(N))
         else:
-            N = int(np.floor(self.n_bin/2))
-            self.tau = np.arange(self.tau[0] - self.dtau*N, 
-                                 self.tau[-1] + self.dtau*(N+1),
-                                 self.dtau)
+            N = int(np.floor(self.n_bin / 2))
+            self.tau = np.arange(self.tau[0] - self.dtau * N,
+                                 self.tau[-1] + self.dtau * (N+1), self.dtau)
             self.rho = np.append(self.rho, np.zeros(N))
-            self.rho = np.insert(self.rho, 0, np.zeros(N+1))
-            
+            self.rho = np.insert(self.rho, 0, np.zeros(N + 1))
+
         if len(self.tau) != len(self.rho):
             self.tau = np.append(self.tau, self.tau[-1] + self.dtau)
-            
+
         self.tau_mean = np.mean(self.tau)
         self.tau -= self.tau_mean
-            
+
     def dipole_moment(self, bunch, plane, tau0):
         """
         Return the dipole moment of the bunch computed on the same time array 
@@ -148,29 +148,28 @@ class WakePotential(Element):
             Dipole moment of the bunch.
 
         """
-        dipole = np.zeros((self.n_bin - 1,))
+        dipole = np.zeros((self.n_bin - 1, ))
         for i in range(self.n_bin - 1):
             dipole[i] = bunch[plane][self.sorted_index == i].sum()
-        dipole = dipole/self.profile
+        dipole = dipole / self.profile
         dipole[np.isnan(dipole)] = 0
-        
+
         # Add N values to get same size as tau/profile
         if self.n_bin % 2 == 0:
-            N = int(self.n_bin/2)
+            N = int(self.n_bin / 2)
             dipole = np.append(dipole, np.zeros(N))
             dipole = np.insert(dipole, 0, np.zeros(N))
         else:
-            N = int(np.floor(self.n_bin/2))
+            N = int(np.floor(self.n_bin / 2))
             dipole = np.append(dipole, np.zeros(N))
-            dipole = np.insert(dipole, 0, np.zeros(N+1))
-            
+            dipole = np.insert(dipole, 0, np.zeros(N + 1))
+
         # Interpole on tau0 to get the same size as W0
         dipole0 = np.interp(tau0, self.tau, dipole, 0, 0)
-            
+
         setattr(self, "dipole_" + plane, dipole0)
         return dipole0
-    
-    
+
     def prepare_wakefunction(self, wake_type, tau, save_data=True):
         """
         Prepare the wake function of a given wake_type to be used for the wake
@@ -204,41 +203,37 @@ class WakePotential(Element):
         tau0 = np.array(getattr(self.wakefield, wake_type).data.index)
         dtau0 = tau0[1] - tau0[0]
         W0 = np.array(getattr(self.wakefield, wake_type).data["real"])
-        
+
         # Keep only the wake function on the rho window
-        ind = np.all([min(tau[0], 0) < tau0, max(tau[-1], 0) > tau0],
-                     axis=0)
+        ind = np.all([min(tau[0], 0) < tau0, max(tau[-1], 0) > tau0], axis=0)
         tau0 = tau0[ind]
         W0 = W0[ind]
-        
+
         # Check the wake function window for assymetry
         assym = (np.abs(tau0[-1]) - np.abs(tau0[0])) / dtau0
         n_assym = int(np.floor(assym))
         if np.floor(assym) > 1:
-            
+
             # add at head
-            if np.abs(tau0[-1]) >  np.abs(tau0[0]):
-                tau0 = np.arange(tau0[0] - dtau0*n_assym, 
-                                 tau0[-1] + dtau0, 
+            if np.abs(tau0[-1]) > np.abs(tau0[0]):
+                tau0 = np.arange(tau0[0] - dtau0*n_assym, tau0[-1] + dtau0,
                                  dtau0)
                 n_to_add = len(tau0) - len(W0)
                 W0 = np.insert(W0, 0, np.zeros(n_to_add))
-                
+
             # add at tail
-            elif np.abs(tau0[0]) >  np.abs(tau0[-1]):
-                tau0 = np.arange(tau0[0], 
-                                 tau0[-1] + dtau0*(n_assym+1), 
+            elif np.abs(tau0[0]) > np.abs(tau0[-1]):
+                tau0 = np.arange(tau0[0], tau0[-1] + dtau0 * (n_assym+1),
                                  dtau0)
                 n_to_add = len(tau0) - len(W0)
                 W0 = np.insert(W0, 0, np.zeros(n_to_add))
-                
+
         # Check is the wf is shorter than rho then add zeros
         if (tau0[0] > tau[0]) or (tau0[-1] < tau[-1]):
-            n = max(int(np.ceil((tau0[0] - tau[0])/dtau0)),
-                    int(np.ceil((tau[-1] - tau0[-1])/dtau0)))
-            
-            tau0 = np.arange(tau0[0] - dtau0*n, 
-                             tau0[-1] + dtau0*(n+1), 
+            n = max(int(np.ceil((tau0[0] - tau[0]) / dtau0)),
+                    int(np.ceil((tau[-1] - tau0[-1]) / dtau0)))
+
+            tau0 = np.arange(tau0[0] - dtau0*n, tau0[-1] + dtau0 * (n+1),
                              dtau0)
             W0 = np.insert(W0, 0, np.zeros(n))
             n_to_add = len(tau0) - len(W0)
@@ -248,9 +243,9 @@ class WakePotential(Element):
             setattr(self, "tau0_" + wake_type, tau0)
             setattr(self, "dtau0_" + wake_type, dtau0)
             setattr(self, "W0_" + wake_type, W0)
-            
+
         return (tau0, dtau0, W0)
-        
+
     def get_wakepotential(self, bunch, wake_type):
         """
         Return the wake potential computed on the wake function time array 
@@ -270,24 +265,24 @@ class WakePotential(Element):
         """
 
         (tau0, dtau0, W0) = self.prepare_wakefunction(wake_type, self.tau)
-        
+
         profile0 = np.interp(tau0, self.tau, self.rho, 0, 0)
-        
+
         if wake_type == "Wlong" or wake_type == "Wxquad" or wake_type == "Wyquad" or wake_type == "Wxcst" or wake_type == "Wycst":
-            Wp = signal.convolve(profile0, W0*-1, mode='same')*dtau0
+            Wp = signal.convolve(profile0, W0 * -1, mode='same') * dtau0
         elif wake_type == "Wxdip":
             dipole0 = self.dipole_moment(bunch, "x", tau0)
-            Wp = signal.convolve(profile0*dipole0, W0, mode='same')*dtau0
+            Wp = signal.convolve(profile0 * dipole0, W0, mode='same') * dtau0
         elif wake_type == "Wydip":
             dipole0 = self.dipole_moment(bunch, "y", tau0)
-            Wp = signal.convolve(profile0*dipole0, W0, mode='same')*dtau0
+            Wp = signal.convolve(profile0 * dipole0, W0, mode='same') * dtau0
         else:
             raise ValueError("This type of wake is not taken into account.")
-        
-        setattr(self,"profile0_" + wake_type, profile0)
+
+        setattr(self, "profile0_" + wake_type, profile0)
         setattr(self, wake_type, Wp)
         return tau0, Wp
-    
+
     @Element.parallel
     def track(self, bunch):
         """
@@ -300,12 +295,13 @@ class WakePotential(Element):
         bunch : Bunch or Beam object.
         
         """
-        
+
         if len(bunch) != 0:
             self.charge_density(bunch)
             for wake_type in self.types:
                 tau0, Wp = self.get_wakepotential(bunch, wake_type)
-                Wp_interp = np.interp(bunch["tau"], tau0 + self.tau_mean, Wp, 0, 0)
+                Wp_interp = np.interp(bunch["tau"], tau0 + self.tau_mean, Wp,
+                                      0, 0)
                 if wake_type == "Wlong":
                     bunch["delta"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wxdip":
@@ -313,17 +309,20 @@ class WakePotential(Element):
                 elif wake_type == "Wydip":
                     bunch["yp"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wxquad":
-                    bunch["xp"] += (bunch["x"] * Wp_interp * bunch.charge 
-                                    / self.ring.E0)
+                    bunch["xp"] += (bunch["x"] * Wp_interp * bunch.charge /
+                                    self.ring.E0)
                 elif wake_type == "Wyquad":
-                    bunch["yp"] += (bunch["y"] * Wp_interp * bunch.charge 
-                                    / self.ring.E0)
+                    bunch["yp"] += (bunch["y"] * Wp_interp * bunch.charge /
+                                    self.ring.E0)
                 elif wake_type == "Wxcst":
                     bunch["xp"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wycst":
                     bunch["yp"] += Wp_interp * bunch.charge / self.ring.E0
-                
-    def plot_last_wake(self, wake_type, plot_rho=True, plot_dipole=False, 
+
+    def plot_last_wake(self,
+                       wake_type,
+                       plot_rho=True,
+                       plot_dipole=False,
                        plot_wake_function=True):
         """
         Plot the last wake potential of a given type computed during the last
@@ -345,45 +344,56 @@ class WakePotential(Element):
         fig : figure
 
         """
-        
-        labels = {"Wlong" : r"$W_{p,long}$ (V/pC)", 
-                  "Wxdip" : r"$W_{p,x}^{D} (V/pC)$",
-                  "Wydip" : r"$W_{p,y}^{D} (V/pC)$",
-                  "Wxquad" : r"$W_{p,x}^{Q} (V/pC/m)$",
-                  "Wyquad" : r"$W_{p,y}^{Q} (V/pC/m)$",
-                  "Wxcst" : r"$W_{p,x}^{M}$ (V/pC)",
-                  "Wycst" : r"$W_{p,y}^{M}$ (V/pC)",}
-        
+
+        labels = {
+            "Wlong": r"$W_{p,long}$ (V/pC)",
+            "Wxdip": r"$W_{p,x}^{D} (V/pC)$",
+            "Wydip": r"$W_{p,y}^{D} (V/pC)$",
+            "Wxquad": r"$W_{p,x}^{Q} (V/pC/m)$",
+            "Wyquad": r"$W_{p,y}^{Q} (V/pC/m)$",
+            "Wxcst": r"$W_{p,x}^{M}$ (V/pC)",
+            "Wycst": r"$W_{p,y}^{M}$ (V/pC)",
+        }
+
         Wp = getattr(self, wake_type)
         tau0 = getattr(self, "tau0_" + wake_type)
-        
+
         fig, ax = plt.subplots()
-        ax.plot(tau0*1e12, Wp*1e-12, label=labels[wake_type])
+        ax.plot(tau0 * 1e12, Wp * 1e-12, label=labels[wake_type])
         ax.set_xlabel("$\\tau$ (ps)")
         ax.set_ylabel(labels[wake_type])
 
         if plot_rho is True:
-            profile0 = getattr(self,"profile0_" + wake_type)
-            profile_rescaled = profile0/max(profile0)*max(np.abs(Wp))
-            rho_rescaled = self.rho/max(self.rho)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, profile_rescaled*1e-12, label=r"$\rho$ interpolated (a.u.)")
-            ax.plot((self.tau + self.tau_mean)*1e12, rho_rescaled*1e-12, label=r"$\rho$ (a.u.)", linestyle='dashed')
+            profile0 = getattr(self, "profile0_" + wake_type)
+            profile_rescaled = profile0 / max(profile0) * max(np.abs(Wp))
+            rho_rescaled = self.rho / max(self.rho) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    profile_rescaled * 1e-12,
+                    label=r"$\rho$ interpolated (a.u.)")
+            ax.plot((self.tau + self.tau_mean) * 1e12,
+                    rho_rescaled * 1e-12,
+                    label=r"$\rho$ (a.u.)",
+                    linestyle='dashed')
             plt.legend()
-            
+
         if plot_wake_function is True:
             W0 = getattr(self, "W0_" + wake_type)
-            W0_rescaled = W0/max(W0)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, W0_rescaled*1e-12, label=r"$W_{function}$ (a.u.)")
+            W0_rescaled = W0 / max(W0) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    W0_rescaled * 1e-12,
+                    label=r"$W_{function}$ (a.u.)")
             plt.legend()
-            
+
         if plot_dipole is True:
             dipole = getattr(self, "dipole_" + wake_type[1])
-            dipole_rescaled = dipole/max(dipole)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, dipole_rescaled*1e-12, label=r"Dipole moment (a.u.)")
+            dipole_rescaled = dipole / max(dipole) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    dipole_rescaled * 1e-12,
+                    label=r"Dipole moment (a.u.)")
             plt.legend()
-            
+
         return fig
-    
+
     def get_gaussian_wakepotential(self, sigma, wake_type, dipole=1e-3):
         """
         Return the wake potential computed using a perfect gaussian profile.
@@ -410,26 +420,31 @@ class WakePotential(Element):
             Dipole moment.
 
         """
-        
-        tau = np.linspace(-10*sigma,10*sigma, int(1e3))
+
+        tau = np.linspace(-10 * sigma, 10 * sigma, int(1e3))
         (tau0, dtau0, W0) = self.prepare_wakefunction(wake_type, tau, False)
-        
+
         profile0 = gaussian_bunch(tau0, sigma)
-        dipole0 = np.ones_like(profile0)*dipole
-        
+        dipole0 = np.ones_like(profile0) * dipole
+
         if wake_type == "Wlong" or wake_type == "Wxquad" or wake_type == "Wyquad":
-            Wp = signal.convolve(profile0, W0*-1, mode='same')*dtau0
+            Wp = signal.convolve(profile0, W0 * -1, mode='same') * dtau0
         elif wake_type == "Wxdip":
-            Wp = signal.convolve(profile0*dipole0, W0, mode='same')*dtau0
+            Wp = signal.convolve(profile0 * dipole0, W0, mode='same') * dtau0
         elif wake_type == "Wydip":
-            Wp = signal.convolve(profile0*dipole0, W0, mode='same')*dtau0
+            Wp = signal.convolve(profile0 * dipole0, W0, mode='same') * dtau0
         else:
             raise ValueError("This type of wake is not taken into account.")
 
         return tau0, W0, Wp, profile0, dipole0
-        
-    def plot_gaussian_wake(self, sigma, wake_type, dipole=1e-3, plot_rho=True, 
-                           plot_dipole=False, plot_wake_function=True):
+
+    def plot_gaussian_wake(self,
+                           sigma,
+                           wake_type,
+                           dipole=1e-3,
+                           plot_rho=True,
+                           plot_dipole=False,
+                           plot_wake_function=True):
         """
         Plot the wake potential of a given type for a perfect gaussian bunch.
 
@@ -453,37 +468,46 @@ class WakePotential(Element):
         fig : figure
 
         """
-        
-        labels = {"Wlong" : r"$W_{p,long}$ (V/pC)", 
-                  "Wxdip" : r"$W_{p,x}^{D} (V/pC)$",
-                  "Wydip" : r"$W_{p,y}^{D} (V/pC)$",
-                  "Wxquad" : r"$W_{p,x}^{Q} (V/pC/m)$",
-                  "Wyquad" : r"$W_{p,y}^{Q} (V/pC/m)$"}
-        
-        tau0, W0, Wp, profile0, dipole0 = self.get_gaussian_wakepotential(sigma, wake_type, dipole)
-        
+
+        labels = {
+            "Wlong": r"$W_{p,long}$ (V/pC)",
+            "Wxdip": r"$W_{p,x}^{D} (V/pC)$",
+            "Wydip": r"$W_{p,y}^{D} (V/pC)$",
+            "Wxquad": r"$W_{p,x}^{Q} (V/pC/m)$",
+            "Wyquad": r"$W_{p,y}^{Q} (V/pC/m)$"
+        }
+
+        tau0, W0, Wp, profile0, dipole0 = self.get_gaussian_wakepotential(
+            sigma, wake_type, dipole)
+
         fig, ax = plt.subplots()
-        ax.plot(tau0*1e12, Wp*1e-12, label=labels[wake_type])
+        ax.plot(tau0 * 1e12, Wp * 1e-12, label=labels[wake_type])
         ax.set_xlabel("$\\tau$ (ps)")
         ax.set_ylabel(labels[wake_type])
 
         if plot_rho is True:
-            profile_rescaled = profile0/max(profile0)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, profile_rescaled*1e-12, label=r"$\rho$ (a.u.)")
+            profile_rescaled = profile0 / max(profile0) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    profile_rescaled * 1e-12,
+                    label=r"$\rho$ (a.u.)")
             plt.legend()
-            
+
         if plot_wake_function is True:
-            W0_rescaled = W0/max(W0)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, W0_rescaled*1e-12, label=r"$W_{function}$ (a.u.)")
+            W0_rescaled = W0 / max(W0) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    W0_rescaled * 1e-12,
+                    label=r"$W_{function}$ (a.u.)")
             plt.legend()
-            
+
         if plot_dipole is True:
-            dipole_rescaled = dipole0/max(dipole0)*max(np.abs(Wp))
-            ax.plot(tau0*1e12, dipole_rescaled*1e-12, label=r"Dipole moment (a.u.)")
+            dipole_rescaled = dipole0 / max(dipole0) * max(np.abs(Wp))
+            ax.plot(tau0 * 1e12,
+                    dipole_rescaled * 1e-12,
+                    label=r"Dipole moment (a.u.)")
             plt.legend()
-            
+
         return fig
-    
+
     def reference_loss(self, bunch):
         """
         Calculate the loss factor and kick factor from the wake potential and 
@@ -508,31 +532,31 @@ class WakePotential(Element):
         for wake_type in self.types:
             tau0, Wp = self.get_wakepotential(bunch, wake_type)
             profile0 = getattr(self, "profile0_" + wake_type)
-            factorTD = np.trapz(Wp*profile0, tau0)
-            
+            factorTD = np.trapz(Wp * profile0, tau0)
+
             if wake_type == "Wlong":
                 factorTD *= -1
             if wake_type == "Wxdip":
                 factorTD /= bunch["x"].mean()
             if wake_type == "Wydip":
                 factorTD /= bunch["y"].mean()
-            
+
             Z = getattr(self.wakefield, "Z" + wake_type[1:])
             sigma = bunch['tau'].std()
             factorFD = Z.loss_factor(sigma)
-            
+
             loss.append(factorTD)
             loss_0.append(factorFD)
-            delta_loss.append( (factorTD - factorFD) / factorFD *100 )
+            delta_loss.append((factorTD-factorFD) / factorFD * 100)
             if wake_type == "Wlong":
                 index.append("Wlong [V/C]")
             else:
                 index.append(wake_type + " [V/C/m]")
-            
+
             column = ['TD factor', 'FD factor', 'Relative error [%]']
-            
-        loss_data = pd.DataFrame(np.array([loss, loss_0, delta_loss]).T, 
-                                 columns=column, 
+
+        loss_data = pd.DataFrame(np.array([loss, loss_0, delta_loss]).T,
+                                 columns=column,
                                  index=index)
         return loss_data
 
@@ -547,11 +571,12 @@ class WakePotential(Element):
         """
         for wake_type in self.types:
             idx = getattr(self.wakefield, wake_type).data.index
-            diff = idx[1:]-idx[:-1]
+            diff = idx[1:] - idx[:-1]
             result = np.all(np.isclose(diff, diff[0], atol=1e-15))
             if result is False:
-                raise ValueError("The wake function must be uniformly sampled.")
-    
+                raise ValueError(
+                    "The wake function must be uniformly sampled.")
+
     def reduce_sampling(self, factor):
         """
         Reduce wake function samping by an integer factor.
@@ -565,10 +590,12 @@ class WakePotential(Element):
         """
         for wake_type in self.types:
             idx = getattr(self.wakefield, wake_type).data.index[::factor]
-            getattr(self.wakefield, wake_type).data = getattr(self.wakefield, wake_type).data.loc[idx]
+            getattr(self.wakefield,
+                    wake_type).data = getattr(self.wakefield,
+                                              wake_type).data.loc[idx]
         self.check_sampling()
-    
-    
+
+
 class LongRangeResistiveWall(Element):
     """
     Element to deal with multi-bunch and multi-turn wakes from resistive wall 
@@ -611,8 +638,16 @@ class LongRangeResistiveWall(Element):
     [1] : Skripka, Galina, et al. "Simultaneous computation of intrabunch and 
     interbunch collective beam motions in storage rings." NIM.A (2016).
     """
-    def __init__(self, ring, beam, length, rho, radius, 
-                 types=["Wlong","Wxdip","Wydip"], nt=50, x3=None, y3=None):
+    def __init__(self,
+                 ring,
+                 beam,
+                 length,
+                 rho,
+                 radius,
+                 types=["Wlong", "Wxdip", "Wydip"],
+                 nt=50,
+                 x3=None,
+                 y3=None):
         # parameters
         self.ring = ring
         self.length = length
@@ -623,7 +658,7 @@ class LongRangeResistiveWall(Element):
             self.types = [types]
         elif isinstance(types, list):
             self.types = types
-        
+
         # effective radius for RW
         self.radius = radius
         if x3 is not None:
@@ -634,21 +669,21 @@ class LongRangeResistiveWall(Element):
             self.y3 = y3
         else:
             self.y3 = radius
-        
+
         # constants
-        self.Z0 = mu_0*c
-        self.t0 = (2*self.rho*self.radius**2 / self.Z0)**(1/3) / c
-        
+        self.Z0 = mu_0 * c
+        self.t0 = (2 * self.rho * self.radius**2 / self.Z0)**(1 / 3) / c
+
         # check approximation
-        if 20*self.t0 > ring.T1:
+        if 20 * self.t0 > ring.T1:
             raise ValueError("The approximated wake functions are not valid.")
-        
+
         # init tables
-        self.tau = np.ones((self.nb,self.nt))*1e100
-        self.x = np.zeros((self.nb,self.nt))
-        self.y = np.zeros((self.nb,self.nt))
-        self.charge = np.zeros((self.nb,self.nt))
-        
+        self.tau = np.ones((self.nb, self.nt)) * 1e100
+        self.x = np.zeros((self.nb, self.nt))
+        self.y = np.zeros((self.nb, self.nt))
+        self.charge = np.zeros((self.nb, self.nt))
+
     def Wlong(self, t):
         """
         Approxmiate expression for the longitudinal resistive wall wake 
@@ -665,10 +700,11 @@ class LongRangeResistiveWall(Element):
             Wake function in [V/C].
 
         """
-        wl = (1/(4*pi * self.radius) * np.sqrt(self.Z0 * self.rho / (c * pi) ) / 
-              t**(3/2) ) * self.length * -1
+        wl = (1 / (4 * pi * self.radius) * np.sqrt(self.Z0 * self.rho /
+                                                   (c*pi)) /
+              t**(3 / 2)) * self.length * -1
         return wl
-    
+
     def Wdip(self, t, plane):
         """
         Approxmiate expression for the transverse resistive wall wake 
@@ -693,11 +729,11 @@ class LongRangeResistiveWall(Element):
             r3 = self.y3
         else:
             raise ValueError()
-            
-        wdip = (1 / (pi * r3**3) * np.sqrt(self.Z0 * c * self.rho / pi) / 
-                t**(1/2) * self.length)
+
+        wdip = (1 / (pi * r3**3) * np.sqrt(self.Z0 * c * self.rho / pi) /
+                t**(1 / 2) * self.length)
         return wdip
-    
+
     def update_tables(self, beam):
         """
         Update tables.
@@ -724,25 +760,27 @@ class LongRangeResistiveWall(Element):
         self.x = np.roll(self.x, shift=1, axis=1)
         self.y = np.roll(self.y, shift=1, axis=1)
         self.charge = np.roll(self.charge, shift=1, axis=1)
-        
+
         # update tables
         if beam.mpi_switch:
             beam.mpi.share_means(beam)
             # negative sign as when bunch 0 is tracked, the others are not yet passed
-            self.tau[:,0] = beam.mpi.mean_all[:,4] - beam.bunch_index*self.ring.T1
-            self.x[:,0] = beam.mpi.mean_all[:,0]
-            self.y[:,0] = beam.mpi.mean_all[:,2]
-            self.charge[:,0] = beam.mpi.charge_all
+            self.tau[:,
+                     0] = beam.mpi.mean_all[:,
+                                            4] - beam.bunch_index * self.ring.T1
+            self.x[:, 0] = beam.mpi.mean_all[:, 0]
+            self.y[:, 0] = beam.mpi.mean_all[:, 2]
+            self.charge[:, 0] = beam.mpi.charge_all
         else:
             mean_all = beam.bunch_mean
-            charge_all =  beam.bunch_charge
+            charge_all = beam.bunch_charge
             # negative sign as when bunch 0 is tracked, the others are not yet passed
-            self.tau[:,0] = mean_all[4, beam.filling_pattern] - beam.bunch_index*self.ring.T1
-            self.x[:,0] = mean_all[0, beam.filling_pattern]
-            self.y[:,0] = mean_all[2, beam.filling_pattern]
-            self.charge[:,0] = charge_all[beam.filling_pattern]
-        
-    
+            self.tau[:, 0] = mean_all[
+                4, beam.filling_pattern] - beam.bunch_index * self.ring.T1
+            self.x[:, 0] = mean_all[0, beam.filling_pattern]
+            self.y[:, 0] = mean_all[2, beam.filling_pattern]
+            self.charge[:, 0] = charge_all[beam.filling_pattern]
+
     def get_kick(self, rank, wake_type):
         """
         Compute the wake kick to apply.
@@ -765,20 +803,22 @@ class LongRangeResistiveWall(Element):
             for i in range(self.nb):
                 if (j == 0) and (rank <= i):
                     continue
-                deltaT = self.tau[i,j] - self.tau[rank, 0]
+                deltaT = self.tau[i, j] - self.tau[rank, 0]
                 if wake_type == "Wlong":
-                    sum_kick += self.Wlong(deltaT) * self.charge[i,j]
+                    sum_kick += self.Wlong(deltaT) * self.charge[i, j]
                 elif wake_type == "Wxdip":
-                    sum_kick += self.Wdip(deltaT, "x") * self.charge[i,j] * self.x[i,j]
+                    sum_kick += self.Wdip(
+                        deltaT, "x") * self.charge[i, j] * self.x[i, j]
                 elif wake_type == "Wydip":
-                    sum_kick += self.Wdip(deltaT, "y") * self.charge[i,j] * self.y[i,j]
+                    sum_kick += self.Wdip(
+                        deltaT, "y") * self.charge[i, j] * self.y[i, j]
                 elif wake_type == "Wxquad":
                     raise NotImplementedError()
                 elif wake_type == "Wyquad":
                     raise NotImplementedError()
-                    
+
         return sum_kick
-    
+
     def track_bunch(self, bunch, rank):
         """
         Track a bunch.
@@ -808,7 +848,7 @@ class LongRangeResistiveWall(Element):
                 bunch["xp"] += (bunch["x"] * kick / self.ring.E0)
             elif wake_type == "Wyquad":
                 bunch["yp"] += (bunch["y"] * kick / self.ring.E0)
-    
+
     def track(self, beam):
         """
         Track a beam.
@@ -823,14 +863,12 @@ class LongRangeResistiveWall(Element):
 
         """
         self.update_tables(beam)
-        
+
         if beam.mpi_switch:
             rank = beam.mpi.rank
-            bunch_index = beam.mpi.bunch_num # Number of the tracked bunch in this processor
+            bunch_index = beam.mpi.bunch_num  # Number of the tracked bunch in this processor
             bunch = beam[bunch_index]
             self.track_bunch(bunch, rank)
         else:
             for rank, bunch in enumerate(beam.not_empty):
                 self.track_bunch(bunch, rank)
-
-    
\ No newline at end of file
diff --git a/mbtrack2/utilities/__init__.py b/mbtrack2/utilities/__init__.py
index 1a0a7e25469d32d70c11e1c8308cf5fc2c22f47b..ebc678ec9821f1499ec138219f04df6159c73f9d 100644
--- a/mbtrack2/utilities/__init__.py
+++ b/mbtrack2/utilities/__init__.py
@@ -1,17 +1,22 @@
 # -*- coding: utf-8 -*-
-from mbtrack2.utilities.read_impedance import (read_CST,
-                                               read_IW2D,
-                                               read_IW2D_folder,
-                                               read_ABCI,
-                                               read_ECHO2D)
-from mbtrack2.utilities.misc import (effective_impedance,
-                                     yokoya_elliptic,
-                                     beam_loss_factor,
-                                     double_sided_impedance)
-from mbtrack2.utilities.spectrum import (spectral_density,
-                                         gaussian_bunch_spectrum,
-                                         gaussian_bunch,
-                                         beam_spectrum)
-from mbtrack2.utilities.optics import (Optics,
-                                       PhysicalModel)
-from mbtrack2.utilities.beamloading import BeamLoadingEquilibrium
\ No newline at end of file
+from mbtrack2.utilities.beamloading import BeamLoadingEquilibrium
+from mbtrack2.utilities.misc import (
+    beam_loss_factor,
+    double_sided_impedance,
+    effective_impedance,
+    yokoya_elliptic,
+)
+from mbtrack2.utilities.optics import Optics, PhysicalModel
+from mbtrack2.utilities.read_impedance import (
+    read_ABCI,
+    read_CST,
+    read_ECHO2D,
+    read_IW2D,
+    read_IW2D_folder,
+)
+from mbtrack2.utilities.spectrum import (
+    beam_spectrum,
+    gaussian_bunch,
+    gaussian_bunch_spectrum,
+    spectral_density,
+)
diff --git a/mbtrack2/utilities/beamloading.py b/mbtrack2/utilities/beamloading.py
index 84a23a18645ba3e0bcbd8e50a2797ae2c3f3c115..1d208c2a7625349b31b2d254afbe738a0b57180e 100644
--- a/mbtrack2/utilities/beamloading.py
+++ b/mbtrack2/utilities/beamloading.py
@@ -3,11 +3,12 @@
 Module where the BeamLoadingEquilibrium class is defined.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
-from scipy.optimize import root
+import numpy as np
 from scipy.constants import c
 from scipy.integrate import quad
+from scipy.optimize import root
+
 
 class BeamLoadingEquilibrium():
     """Class used to compute beam equilibrium profile for a given storage ring 
@@ -30,21 +31,26 @@ class BeamLoadingEquilibrium():
     B1 : lower intergration boundary
     B2 : upper intergration boundary
     """
-
-    def __init__(
-                self, ring, cavity_list, I0, auto_set_MC_theta=False, F=None,
-                PHI=None, B1=-0.2, B2=0.2):
+    def __init__(self,
+                 ring,
+                 cavity_list,
+                 I0,
+                 auto_set_MC_theta=False,
+                 F=None,
+                 PHI=None,
+                 B1=-0.2,
+                 B2=0.2):
         self.ring = ring
         self.cavity_list = cavity_list
         self.I0 = I0
         self.n_cavity = len(cavity_list)
         self.auto_set_MC_theta = auto_set_MC_theta
         if F is None:
-            self.F = np.ones((self.n_cavity,))
+            self.F = np.ones((self.n_cavity, ))
         else:
             self.F = F
         if PHI is None:
-            self.PHI = np.zeros((self.n_cavity,))
+            self.PHI = np.zeros((self.n_cavity, ))
         else:
             self.PHI = PHI
         self.B1 = B1
@@ -53,64 +59,64 @@ class BeamLoadingEquilibrium():
         self.__version__ = "1.0"
 
         # Define constants for scaled potential u(z)
-        self.u0 = self.ring.U0 / (
-            self.ring.ac * self.ring.sigma_delta**2
-            * self.ring.E0 * self.ring.L)
-        self.ug = np.zeros((self.n_cavity,))
-        self.ub = np.zeros((self.n_cavity,))
+        self.u0 = self.ring.U0 / (self.ring.ac * self.ring.sigma_delta**2 *
+                                  self.ring.E0 * self.ring.L)
+        self.ug = np.zeros((self.n_cavity, ))
+        self.ub = np.zeros((self.n_cavity, ))
         self.update_potentials()
-            
+
     def update_potentials(self):
         """Update potentials with cavity and ring data."""
         for i in range(self.n_cavity):
             cavity = self.cavity_list[i]
-            self.ug[i] = cavity.Vg / (
-                self.ring.ac * self.ring.sigma_delta ** 2 *
-                self.ring.E0 * self.ring.L * self.ring.k1 *
-                cavity.m)
+            self.ug[i] = cavity.Vg / (self.ring.ac * self.ring.sigma_delta**2 *
+                                      self.ring.E0 * self.ring.L *
+                                      self.ring.k1 * cavity.m)
             self.ub[i] = 2 * self.I0 * cavity.Rs / (
-                self.ring.ac * self.ring.sigma_delta**2 *
-                self.ring.E0 * self.ring.L * self.ring.k1 *
-                cavity.m * (1 + cavity.beta))
-        
+                self.ring.ac * self.ring.sigma_delta**2 * self.ring.E0 *
+                self.ring.L * self.ring.k1 * cavity.m * (1 + cavity.beta))
+
     def energy_balance(self):
         """Return energy balance for the synchronous particle
         (z = 0 ,delta = 0)."""
         delta = self.ring.U0
         for i in range(self.n_cavity):
             cavity = self.cavity_list[i]
-            delta += cavity.Vb(self.I0) * self.F[i] * np.cos(cavity.psi - self.PHI[i])
+            delta += cavity.Vb(
+                self.I0) * self.F[i] * np.cos(cavity.psi - self.PHI[i])
             delta -= cavity.Vg * np.cos(cavity.theta_g)
         return delta
-    
+
     def center_of_mass(self):
         """Return center of mass position in [s]"""
         z0 = np.linspace(self.B1, self.B2, 1000)
         rho = self.rho(z0)
         CM = np.average(z0, weights=rho)
-        return CM/c
+        return CM / c
 
     def u(self, z):
         """Scaled potential u(z)"""
         pot = self.u0 * z
         for i in range(self.n_cavity):
             cavity = self.cavity_list[i]
-            pot += - self.ug[i] * (
-                np.sin(self.ring.k1 * cavity.m * z + cavity.theta_g)
-                - np.sin(cavity.theta_g))
+            pot += -self.ug[i] * (
+                np.sin(self.ring.k1 * cavity.m * z + cavity.theta_g) -
+                np.sin(cavity.theta_g))
             pot += self.ub[i] * self.F[i] * np.cos(cavity.psi) * (
-                np.sin(self.ring.k1 * cavity.m * z
-                       + cavity.psi - self.PHI[i])
+                np.sin(self.ring.k1 * cavity.m * z + cavity.psi - self.PHI[i])
                 - np.sin(cavity.psi - self.PHI[i]))
         return pot
-    
+
     def du_dz(self, z):
         """Partial derivative of the scaled potential u(z) by z"""
         pot = self.u0
         for i in range(self.n_cavity):
             cavity = self.cavity_list[i]
-            pot += - self.ug[i] * self.ring.k1 * cavity.m * np.cos(self.ring.k1 * cavity.m * z + cavity.theta_g)
-            pot += self.ub[i] * self.F[i] * self.ring.k1 * cavity.m * np.cos(cavity.psi) * np.cos(self.ring.k1 * cavity.m * z + cavity.psi - self.PHI[i])
+            pot += -self.ug[i] * self.ring.k1 * cavity.m * np.cos(
+                self.ring.k1 * cavity.m * z + cavity.theta_g)
+            pot += self.ub[i] * self.F[i] * self.ring.k1 * cavity.m * np.cos(
+                cavity.psi) * np.cos(self.ring.k1 * cavity.m * z + cavity.psi -
+                                     self.PHI[i])
         return pot
 
     def uexp(self, z):
@@ -146,26 +152,34 @@ class BeamLoadingEquilibrium():
 
         # Compute system
         if self.auto_set_MC_theta:
-            res = np.zeros((self.n_cavity * 2 + 1,))
+            res = np.zeros((self.n_cavity * 2 + 1, ))
             for i in range(self.n_cavity):
                 cavity = self.cavity_list[i]
-                res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
-                    lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
-                res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
-                    lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
+                res[2 *
+                    i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
+                        lambda y: self.uexp(y),
+                        lambda y: np.cos(self.ring.k1 * cavity.m * y))
+                res[2*i +
+                    1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
+                        lambda y: self.uexp(y),
+                        lambda y: np.sin(self.ring.k1 * cavity.m * y))
             # Factor 1e-8 or 1e12 for better convergence
             if CM is True:
                 res[self.n_cavity * 2] = self.center_of_mass() * 1e12
             else:
                 res[self.n_cavity * 2] = self.energy_balance() * 1e-8
         else:
-            res = np.zeros((self.n_cavity * 2,))
+            res = np.zeros((self.n_cavity * 2, ))
             for i in range(self.n_cavity):
                 cavity = self.cavity_list[i]
-                res[2 * i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
-                    lambda y: self.uexp(y), lambda y: np.cos(self.ring.k1 * cavity.m * y))
-                res[2 * i + 1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
-                    lambda y: self.uexp(y), lambda y: np.sin(self.ring.k1 * cavity.m * y))
+                res[2 *
+                    i] = self.F[i] * np.cos(self.PHI[i]) - self.integrate_func(
+                        lambda y: self.uexp(y),
+                        lambda y: np.cos(self.ring.k1 * cavity.m * y))
+                res[2*i +
+                    1] = self.F[i] * np.sin(self.PHI[i]) - self.integrate_func(
+                        lambda y: self.uexp(y),
+                        lambda y: np.sin(self.ring.k1 * cavity.m * y))
         return res
 
     def rho(self, z):
@@ -180,10 +194,10 @@ class BeamLoadingEquilibrium():
         if z2 is None:
             z2 = self.B2
         z0 = np.linspace(z1, z2, 1000)
-        plt.plot(z0/c*1e12, self.rho(z0))
+        plt.plot(z0 / c * 1e12, self.rho(z0))
         plt.xlabel(r"$\tau$ [ps]")
         plt.title("Equilibrium bunch profile")
-        
+
     def voltage(self, z):
         """Return the RF system total voltage at position z"""
         Vtot = 0
@@ -191,7 +205,7 @@ class BeamLoadingEquilibrium():
             cavity = self.cavity_list[i]
             Vtot += cavity.VRF(z, self.I0, self.F[i], self.PHI[i])
         return Vtot
-    
+
     def dV(self, z):
         """Return derivative of the RF system total voltage at position z"""
         Vtot = 0
@@ -199,7 +213,7 @@ class BeamLoadingEquilibrium():
             cavity = self.cavity_list[i]
             Vtot += cavity.dVRF(z, self.I0, self.F[i], self.PHI[i])
         return Vtot
-    
+
     def ddV(self, z):
         """Return the second derivative of the RF system total voltage at position z"""
         Vtot = 0
@@ -207,7 +221,7 @@ class BeamLoadingEquilibrium():
             cavity = self.cavity_list[i]
             Vtot += cavity.ddVRF(z, self.I0, self.F[i], self.PHI[i])
         return Vtot
-    
+
     def deltaVRF(self, z):
         """Return the generator voltage minus beam loading voltage of the total RF system at position z"""
         Vtot = 0
@@ -215,7 +229,7 @@ class BeamLoadingEquilibrium():
             cavity = self.cavity_list[i]
             Vtot += cavity.deltaVRF(z, self.I0, self.F[i], self.PHI[i])
         return Vtot
-    
+
     def plot_dV(self, z1=None, z2=None):
         """Plot the derivative of RF system total voltage between z1 and z2"""
         if z1 is None:
@@ -226,7 +240,7 @@ class BeamLoadingEquilibrium():
         plt.plot(z0, self.dV(z0))
         plt.xlabel("z [m]")
         plt.ylabel("Total RF voltage (V)")
-        
+
     def plot_voltage(self, z1=None, z2=None):
         """Plot the RF system total voltage between z1 and z2"""
         if z1 is None:
@@ -250,8 +264,13 @@ class BeamLoadingEquilibrium():
         variance = np.average((z0 - average)**2, weights=values)
         return np.sqrt(variance)
 
-    def beam_equilibrium(self, x0=None, tol=1e-4, method='hybr', options=None, 
-                         plot = False, CM=True):
+    def beam_equilibrium(self,
+                         x0=None,
+                         tol=1e-4,
+                         method='hybr',
+                         options=None,
+                         plot=False,
+                         CM=True):
         """Solve system of non-linear equation to find the form factors F
         and PHI at equilibrum. Can be used to compute the equilibrium bunch
         profile.
@@ -277,12 +296,15 @@ class BeamLoadingEquilibrium():
 
         if CM:
             print("The initial center of mass offset is " +
-                  str(self.center_of_mass()*1e12) + " ps")
+                  str(self.center_of_mass() * 1e12) + " ps")
         else:
             print("The initial energy balance is " +
                   str(self.energy_balance()) + " eV")
 
-        sol = root(lambda x : self.to_solve(x, CM), x0, tol=tol, method=method, 
+        sol = root(lambda x: self.to_solve(x, CM),
+                   x0,
+                   tol=tol,
+                   method=method,
                    options=options)
 
         # Update values of F, PHI and theta_g
@@ -298,17 +320,17 @@ class BeamLoadingEquilibrium():
 
         if CM:
             print("The final center of mass offset is " +
-                  str(self.center_of_mass()*1e12) + " ps")
+                  str(self.center_of_mass() * 1e12) + " ps")
         else:
-            print("The final energy balance is " +
-                  str(self.energy_balance()) + " eV")
+            print("The final energy balance is " + str(self.energy_balance()) +
+                  " eV")
         print("The algorithm has converged: " + str(sol.success))
-        
+
         if plot:
             self.plot_rho(self.B1 / 4, self.B2 / 4)
 
         return sol
-    
+
     def PTBL_threshold(self, I0, m=None, MC_index=0, HC_index=1):
         """
         Compute the periodic transient beam loading (PTBL) instability 
@@ -352,26 +374,27 @@ class BeamLoadingEquilibrium():
         """
         if m is None:
             m = self.ring.h
-            
+
         MC = self.cavity_list[MC_index]
         HC = self.cavity_list[HC_index]
-        a = np.exp(-HC.wr*self.ring.T0/(2*HC.Q))
-        theta = HC.detune*2*np.pi*self.ring.T0
-        dtheta = np.arcsin((1-a)*np.cos(theta/2)/(np.sqrt(1+a**2-2*a*np.cos(theta))))
-        
+        a = np.exp(-HC.wr * self.ring.T0 / (2 * HC.Q))
+        theta = HC.detune * 2 * np.pi * self.ring.T0
+        dtheta = np.arcsin((1-a) * np.cos(theta / 2) /
+                           (np.sqrt(1 + a**2 - 2 * a * np.cos(theta))))
+
         k = np.arange(1, m)
-        d_k = np.exp(-1*HC.wr*self.ring.T0*(k-1)/(2*HC.Q*m))
-        theta_k = (theta/2 + dtheta - (k-1)/m * theta)
-        eps_k = 1 - np.sin(np.pi/2 - k*2*np.pi/m)
-                   
+        d_k = np.exp(-1 * HC.wr * self.ring.T0 * (k-1) / (2 * HC.Q * m))
+        theta_k = (theta/2 + dtheta - (k-1) / m * theta)
+        eps_k = 1 - np.sin(np.pi / 2 - k * 2 * np.pi / m)
+
         num = np.sum(eps_k * d_k * np.cos(theta_k))
-        f = num / ( m * np.sqrt(1 + a**2 - 2*a*np.cos(theta)) )
-        
-        eta = (2 * np.pi * HC.m**2 * self.F[HC_index] * self.ring.h * I0 * 
-               HC.Rs / HC.Q * f / 
-               ( MC.Vc * np.sin(MC.theta - self.PHI[HC_index] / HC.m) ) )
-        
-        RQth = (MC.Vc * np.sin(MC.theta - self.PHI[HC_index] / HC.m) / 
+        f = num / (m * np.sqrt(1 + a**2 - 2 * a * np.cos(theta)))
+
+        eta = (2 * np.pi * HC.m**2 * self.F[HC_index] * self.ring.h * I0 *
+               HC.Rs / HC.Q * f /
+               (MC.Vc * np.sin(MC.theta - self.PHI[HC_index] / HC.m)))
+
+        RQth = (MC.Vc * np.sin(MC.theta - self.PHI[HC_index] / HC.m) /
                 (2 * np.pi * HC.m**2 * self.F[1] * self.ring.h * I0 * f))
-        
-        return (eta, RQth, f)
\ No newline at end of file
+
+        return (eta, RQth, f)
diff --git a/mbtrack2/utilities/misc.py b/mbtrack2/utilities/misc.py
index 7b06a1af09dd7a8f50d7fa955fff06316745e3e5..ec7d4ff9c6c7def41d54fa9965b882b927f88c64 100644
--- a/mbtrack2/utilities/misc.py
+++ b/mbtrack2/utilities/misc.py
@@ -3,15 +3,24 @@
 This module defines miscellaneous utilities functions.
 """
 
-import pandas as pd
-import numpy as np
 from pathlib import Path
+
+import numpy as np
+import pandas as pd
 from scipy.interpolate import interp1d
+
 from mbtrack2.impedance.wakefield import Impedance
 from mbtrack2.utilities.spectrum import spectral_density
 
 
-def effective_impedance(ring, imp, m, mu, sigma, M, tuneS, xi=None,
+def effective_impedance(ring,
+                        imp,
+                        m,
+                        mu,
+                        sigma,
+                        M,
+                        tuneS,
+                        xi=None,
                         mode="Hermite"):
     """
     Compute the effective (longitudinal or transverse) impedance. 
@@ -58,40 +67,40 @@ def effective_impedance(ring, imp, m, mu, sigma, M, tuneS, xi=None,
         double_sided_impedance(imp)
 
     if mode in ["Hermite", "Legendre", "Sinusoidal", "Sacherer", "Chebyshev"]:
+
         def h(f):
-            return spectral_density(frequency=f, sigma=sigma, m=m,
-                                    mode=mode)
+            return spectral_density(frequency=f, sigma=sigma, m=m, mode=mode)
     else:
         raise NotImplementedError("Not implemanted yet.")
 
     if imp.component_type == "long":
-        pmax = fmax/(ring.f0 * M) - 1
-        pmin = fmin/(ring.f0 * M) + 1
-        p = np.arange(pmin, pmax+1)
+        pmax = fmax / (ring.f0 * M) - 1
+        pmin = fmin / (ring.f0 * M) + 1
+        p = np.arange(pmin, pmax + 1)
 
-        fp = ring.f0*(p*M + mu + m*tuneS)
+        fp = ring.f0 * (p*M + mu + m*tuneS)
         fp = fp[np.nonzero(fp)]  # Avoid division by 0
-        num = np.sum(imp(fp) * h(fp) / (fp*2*np.pi))
+        num = np.sum(imp(fp) * h(fp) / (fp * 2 * np.pi))
         den = np.sum(h(fp))
-        Zeff = num/den
+        Zeff = num / den
 
     elif imp.component_type == "xdip" or imp.component_type == "ydip":
         if imp.component_type == "xdip":
-            tuneXY = ring.tune[0]-np.floor(ring.tune[0])
+            tuneXY = ring.tune[0] - np.floor(ring.tune[0])
             if xi is None:
                 xi = ring.chro[0]
         elif imp.component_type == "ydip":
-            tuneXY = ring.tune[1]-np.floor(ring.tune[1])
+            tuneXY = ring.tune[1] - np.floor(ring.tune[1])
             if xi is None:
                 xi = ring.chro[1]
-        pmax = fmax/(ring.f0 * M) - 1
-        pmin = fmin/(ring.f0 * M) + 1
-        p = np.arange(pmin, pmax+1)
-        fp = ring.f0*(p*M + mu + tuneXY + m*tuneS)
-        f_xi = xi/ring.eta()*ring.f0
+        pmax = fmax / (ring.f0 * M) - 1
+        pmin = fmin / (ring.f0 * M) + 1
+        p = np.arange(pmin, pmax + 1)
+        fp = ring.f0 * (p*M + mu + tuneXY + m*tuneS)
+        f_xi = xi / ring.eta() * ring.f0
         num = np.sum(imp(fp) * h(fp - f_xi))
         den = np.sum(h(fp - f_xi))
-        Zeff = num/den
+        Zeff = num / den
     else:
         raise TypeError("Effective impedance is only defined for long, xdip"
                         " and ydip impedance type.")
@@ -99,7 +108,14 @@ def effective_impedance(ring, imp, m, mu, sigma, M, tuneS, xi=None,
     return Zeff
 
 
-def head_tail_form_factor(ring, imp, m, sigma, tuneS, xi=None, mode="Hermite", mu=0):
+def head_tail_form_factor(ring,
+                          imp,
+                          m,
+                          sigma,
+                          tuneS,
+                          xi=None,
+                          mode="Hermite",
+                          mu=0):
     M = 1
     if not isinstance(imp, Impedance):
         raise TypeError("{} should be an Impedance object.".format(imp))
@@ -110,33 +126,33 @@ def head_tail_form_factor(ring, imp, m, sigma, tuneS, xi=None, mode="Hermite", m
         double_sided_impedance(imp)
 
     if mode in ["Hermite", "Legendre", "Sinusoidal", "Sacherer", "Chebyshev"]:
+
         def h(f):
-            return spectral_density(frequency=f, sigma=sigma, m=m,
-                                    mode=mode)
+            return spectral_density(frequency=f, sigma=sigma, m=m, mode=mode)
     else:
         raise NotImplementedError("Not implemanted yet.")
 
-    pmax = np.floor(fmax/(ring.f0 * M))
-    pmin = np.ceil(fmin/(ring.f0 * M))
+    pmax = np.floor(fmax / (ring.f0 * M))
+    pmin = np.ceil(fmin / (ring.f0 * M))
 
-    p = np.arange(pmin, pmax+1)
+    p = np.arange(pmin, pmax + 1)
 
     if imp.component_type == "long":
-        fp = ring.f0*(p*M + mu + m*tuneS)
+        fp = ring.f0 * (p*M + mu + m*tuneS)
         fp = fp[np.nonzero(fp)]  # Avoid division by 0
         den = np.sum(h(fp))
 
     elif imp.component_type == "xdip" or imp.component_type == "ydip":
         if imp.component_type == "xdip":
-            tuneXY = ring.tune[0]-np.floor(ring.tune[0])
+            tuneXY = ring.tune[0] - np.floor(ring.tune[0])
             if xi is None:
                 xi = ring.chro[0]
         elif imp.component_type == "ydip":
-            tuneXY = ring.tune[1]-np.floor(ring.tune[0])
+            tuneXY = ring.tune[1] - np.floor(ring.tune[0])
             if xi is None:
                 xi = ring.chro[1]
-        fp = ring.f0*(p*M + mu + tuneXY + m*tuneS)
-        f_xi = xi/ring.eta()*ring.f0
+        fp = ring.f0 * (p*M + mu + tuneXY + m*tuneS)
+        f_xi = xi / ring.eta() * ring.f0
         den = np.sum(h(fp - f_xi))
     else:
         raise TypeError("Effective impedance is only defined for long, xdip"
@@ -148,6 +164,7 @@ def head_tail_form_factor(ring, imp, m, sigma, tuneS, xi=None, mode="Hermite", m
 def tune_shift_from_effective_impedance(Zeff):
     pass
 
+
 def yokoya_elliptic(x_radius, y_radius):
     """
     Compute Yokoya factors for an elliptic beam pipe.
@@ -188,7 +205,7 @@ def yokoya_elliptic(x_radius, y_radius):
     yokoya_file = pd.read_csv(file)
     ratio_col = yokoya_file["x"]
     # compute semi-axes ratio (first column of this file)
-    ratio = (large_semiaxis - small_semiaxis)/(large_semiaxis + small_semiaxis)
+    ratio = (large_semiaxis-small_semiaxis) / (large_semiaxis+small_semiaxis)
 
     # interpolate Yokoya file at the correct ratio
     yoklong = 1
@@ -231,15 +248,15 @@ def beam_loss_factor(impedance, frequency, spectrum, ring):
     [1] : Handbook of accelerator physics and engineering, 3rd printing. 
         Eq (3) p239.
     """
-    pmax = np.floor(impedance.data.index.max()/ring.f0)
-    pmin = np.floor(impedance.data.index.min()/ring.f0)
+    pmax = np.floor(impedance.data.index.max() / ring.f0)
+    pmin = np.floor(impedance.data.index.min() / ring.f0)
 
     if pmin >= 0:
         double_sided_impedance(impedance)
-        pmin = -1*pmax
+        pmin = -1 * pmax
 
-    p = np.arange(pmin+1, pmax)
-    pf0 = p*ring.f0
+    p = np.arange(pmin + 1, pmax)
+    pf0 = p * ring.f0
     ReZ = np.real(impedance(pf0))
     spectral_density = np.abs(spectrum)**2
     # interpolation of the spectrum is needed to avoid problems liked to
@@ -247,7 +264,7 @@ def beam_loss_factor(impedance, frequency, spectrum, ring):
     # computing the spectrum directly to the frequency points gives
     # wrong results
     spect = interp1d(frequency, spectral_density)
-    kloss_beam = ring.f0 * np.sum(ReZ*spect(pf0))
+    kloss_beam = ring.f0 * np.sum(ReZ * spect(pf0))
 
     return kloss_beam
 
@@ -265,19 +282,19 @@ def double_sided_impedance(impedance):
     fmin = impedance.data.index.min()
 
     if fmin >= 0:
-        negative_index = impedance.data.index*-1
+        negative_index = impedance.data.index * -1
         negative_data = impedance.data.set_index(negative_index)
 
         imp_type = impedance.component_type
 
         if imp_type == "long":
-            negative_data["imag"] = -1*negative_data["imag"]
+            negative_data["imag"] = -1 * negative_data["imag"]
 
         elif (imp_type == "xdip") or (imp_type == "ydip"):
-            negative_data["real"] = -1*negative_data["real"]
+            negative_data["real"] = -1 * negative_data["real"]
 
         elif (imp_type == "xquad") or (imp_type == "yquad"):
-            negative_data["real"] = -1*negative_data["real"]
+            negative_data["real"] = -1 * negative_data["real"]
 
         else:
             raise ValueError("Wrong impedance type")
diff --git a/mbtrack2/utilities/optics.py b/mbtrack2/utilities/optics.py
index 979ed949a8bc73b5ca76db58acd0e783793ef9c8..806d52bc793e60c74c903a790ca01bdd7da1ee40 100644
--- a/mbtrack2/utilities/optics.py
+++ b/mbtrack2/utilities/optics.py
@@ -4,8 +4,8 @@ Module where the class to store the optic functions and the lattice physical
 parameters are defined.
 """
 
-import numpy as np
 import matplotlib.pyplot as plt
+import numpy as np
 from scipy.interpolate import interp1d
 
 
@@ -54,10 +54,13 @@ class Optics:
     plot(self, var, option, n_points=1000)
         Plot optical variables.
     """
-    
-    def __init__(self, lattice_file=None, local_beta=None, local_alpha=None, 
-                 local_dispersion=None, **kwargs):
-        
+    def __init__(self,
+                 lattice_file=None,
+                 local_beta=None,
+                 local_alpha=None,
+                 local_dispersion=None,
+                 **kwargs):
+
         if lattice_file is not None:
             self.use_local_values = False
             self.load_from_AT(lattice_file, **kwargs)
@@ -70,16 +73,16 @@ class Optics:
             else:
                 self._local_alpha = local_alpha
             if local_dispersion is None:
-                self.local_dispersion = np.zeros((4,))
+                self.local_dispersion = np.zeros((4, ))
             else:
                 self.local_dispersion = local_dispersion
-            self._local_gamma = (1 + self._local_alpha**2)/self._local_beta
-            
+            self._local_gamma = (1 + self._local_alpha**2) / self._local_beta
+
         else:
             self.use_local_values = True
             self._local_beta = local_beta
             self._local_alpha = local_alpha
-            self._local_gamma = (1 + self._local_alpha**2)/self._local_beta
+            self._local_gamma = (1 + self._local_alpha**2) / self._local_beta
             self.local_dispersion = local_dispersion
 
     def load_from_AT(self, lattice_file, **kwargs):
@@ -99,80 +102,91 @@ class Optics:
         import at
         self.n_points = int(kwargs.get("n_points", 1e3))
         periodicity = kwargs.get("periodicity")
-        
+
         self.lattice = at.load_lattice(lattice_file)
         if self.lattice.radiation:
             self.lattice.radiation_off()
         lattice = self.lattice.slice(slices=self.n_points)
         refpts = np.arange(0, len(lattice))
-        twiss0, tune, chrom, twiss = at.linopt(lattice, refpts=refpts,
-                                                  get_chrom=True)
-        
+        twiss0, tune, chrom, twiss = at.linopt(lattice,
+                                               refpts=refpts,
+                                               get_chrom=True)
+
         if periodicity is None:
             self.periodicity = lattice.periodicity
         else:
             self.periodicity = periodicity
-        
+
         if self.periodicity > 1:
-            for i in range(self.periodicity-1):
-                pos = np.append(twiss.s_pos, twiss.s_pos + twiss.s_pos[-1]*(i+1))
+            for i in range(self.periodicity - 1):
+                pos = np.append(twiss.s_pos,
+                                twiss.s_pos + twiss.s_pos[-1] * (i+1))
         else:
             pos = twiss.s_pos
-            
+
         self.position = pos
         self.beta_array = np.tile(twiss.beta.T, self.periodicity)
         self.alpha_array = np.tile(twiss.alpha.T, self.periodicity)
         self.dispersion_array = np.tile(twiss.dispersion.T, self.periodicity)
         self.mu_array = np.tile(twiss.mu.T, self.periodicity)
-        
+
         self.position = np.append(self.position, self.lattice.circumference)
-        self.beta_array = np.append(self.beta_array, self.beta_array[:,0:1],
+        self.beta_array = np.append(self.beta_array,
+                                    self.beta_array[:, 0:1],
                                     axis=1)
-        self.alpha_array = np.append(self.alpha_array, self.alpha_array[:,0:1],
+        self.alpha_array = np.append(self.alpha_array,
+                                     self.alpha_array[:, 0:1],
                                      axis=1)
         self.dispersion_array = np.append(self.dispersion_array,
-                                          self.dispersion_array[:,0:1], axis=1)
-        self.mu_array = np.append(self.mu_array,
-                                  self.mu_array[:,0:1], axis=1)
-        
-        self.gamma_array = (1 + self.alpha_array**2)/self.beta_array
+                                          self.dispersion_array[:, 0:1],
+                                          axis=1)
+        self.mu_array = np.append(self.mu_array, self.mu_array[:, 0:1], axis=1)
+
+        self.gamma_array = (1 + self.alpha_array**2) / self.beta_array
         self.tune = tune * self.periodicity
         self.chro = chrom * self.periodicity
         self.ac = at.get_mcf(self.lattice)
-        
-        self.mu_array[:,-1] = (np.floor(self.mu_array[:,-2]/(2*np.pi)) +
-                               self.tune)*2*np.pi
-        
+
+        self.mu_array[:, -1] = (np.floor(self.mu_array[:, -2] /
+                                         (2 * np.pi)) + self.tune) * 2 * np.pi
+
         self.setup_interpolation()
-        
-        
-    def setup_interpolation(self):      
+
+    def setup_interpolation(self):
         """Setup interpolation of the optic functions."""
-        self.betaX = interp1d(self.position, self.beta_array[0,:],
+        self.betaX = interp1d(self.position,
+                              self.beta_array[0, :],
                               kind='linear')
-        self.betaY = interp1d(self.position, self.beta_array[1,:],
+        self.betaY = interp1d(self.position,
+                              self.beta_array[1, :],
                               kind='linear')
-        self.alphaX = interp1d(self.position, self.alpha_array[0,:],
+        self.alphaX = interp1d(self.position,
+                               self.alpha_array[0, :],
                                kind='linear')
-        self.alphaY = interp1d(self.position, self.alpha_array[1,:],
+        self.alphaY = interp1d(self.position,
+                               self.alpha_array[1, :],
                                kind='linear')
-        self.gammaX = interp1d(self.position, self.gamma_array[0,:],
+        self.gammaX = interp1d(self.position,
+                               self.gamma_array[0, :],
                                kind='linear')
-        self.gammaY = interp1d(self.position, self.gamma_array[1,:],
+        self.gammaY = interp1d(self.position,
+                               self.gamma_array[1, :],
                                kind='linear')
-        self.dispX = interp1d(self.position, self.dispersion_array[0,:],
+        self.dispX = interp1d(self.position,
+                              self.dispersion_array[0, :],
                               kind='linear')
-        self.disppX = interp1d(self.position, self.dispersion_array[1,:],
+        self.disppX = interp1d(self.position,
+                               self.dispersion_array[1, :],
                                kind='linear')
-        self.dispY = interp1d(self.position, self.dispersion_array[2,:],
+        self.dispY = interp1d(self.position,
+                              self.dispersion_array[2, :],
                               kind='linear')
-        self.disppY = interp1d(self.position, self.dispersion_array[3,:],
+        self.disppY = interp1d(self.position,
+                               self.dispersion_array[3, :],
                                kind='linear')
-        self.muX = interp1d(self.position, self.mu_array[0,:],
-                              kind='linear')
-        self.muY = interp1d(self.position, self.mu_array[1,:],
-                              kind='linear')
-    
+        self.muX = interp1d(self.position, self.mu_array[0, :], kind='linear')
+        self.muY = interp1d(self.position, self.mu_array[1, :], kind='linear')
+
     @property
     def local_beta(self):
         """
@@ -180,7 +194,7 @@ class Optics:
 
         """
         return self._local_beta
-    
+
     @local_beta.setter
     def local_beta(self, beta_array):
         """
@@ -194,8 +208,8 @@ class Optics:
 
         """
         self._local_beta = beta_array
-        self._local_gamma = (1+self._local_alpha**2) / self._local_beta
-        
+        self._local_gamma = (1 + self._local_alpha**2) / self._local_beta
+
     @property
     def local_alpha(self):
         """
@@ -203,7 +217,7 @@ class Optics:
 
         """
         return self._local_alpha
-    
+
     @local_alpha.setter
     def local_alpha(self, alpha_array):
         """
@@ -217,8 +231,8 @@ class Optics:
 
         """
         self._local_alpha = alpha_array
-        self._local_gamma = (1+self._local_alpha**2) / self._local_beta
-    
+        self._local_gamma = (1 + self._local_alpha**2) / self._local_beta
+
     @property
     def local_gamma(self):
         """
@@ -226,7 +240,7 @@ class Optics:
 
         """
         return self._local_gamma
-    
+
     def beta(self, position):
         """
         Return beta functions at specific locations given by position. If no
@@ -247,7 +261,7 @@ class Optics:
         else:
             beta = [self.betaX(position), self.betaY(position)]
             return np.array(beta)
-    
+
     def alpha(self, position):
         """
         Return alpha functions at specific locations given by position. If no
@@ -268,7 +282,7 @@ class Optics:
         else:
             alpha = [self.alphaX(position), self.alphaY(position)]
             return np.array(alpha)
-    
+
     def gamma(self, position):
         """
         Return gamma functions at specific locations given by position. If no
@@ -289,7 +303,7 @@ class Optics:
         else:
             gamma = [self.gammaX(position), self.gammaY(position)]
             return np.array(gamma)
-    
+
     def dispersion(self, position):
         """
         Return dispersion functions at specific locations given by position. 
@@ -309,10 +323,14 @@ class Optics:
         if self.use_local_values:
             return np.outer(self.local_dispersion, np.ones_like(position))
         else:
-            dispersion = [self.dispX(position), self.disppX(position), 
-                          self.dispY(position), self.disppY(position)]
+            dispersion = [
+                self.dispX(position),
+                self.disppX(position),
+                self.dispY(position),
+                self.disppY(position)
+            ]
             return np.array(dispersion)
-        
+
     def mu(self, position):
         """
         Return phase advances at specific locations given by position. 
@@ -329,11 +347,11 @@ class Optics:
             Phase advances.
         """
         if self.use_local_values:
-            return np.outer(np.array([0,0]), np.ones_like(position))
+            return np.outer(np.array([0, 0]), np.ones_like(position))
         else:
             mu = [self.muX(position), self.muY(position)]
             return np.array(mu)
-        
+
     def plot(self, var, option, n_points=1000):
         """
         Plot optical variables.
@@ -351,49 +369,62 @@ class Optics:
             Number of points on the plot. The default value is 1000.
     
         """
-    
-        var_dict = {"beta":self.beta, "alpha":self.alpha, "gamma":self.gamma, 
-                    "dispersion":self.dispersion, "mu":self.mu}
-        
+
+        var_dict = {
+            "beta": self.beta,
+            "alpha": self.alpha,
+            "gamma": self.gamma,
+            "dispersion": self.dispersion,
+            "mu": self.mu
+        }
+
         if var == "dispersion":
-            option_dict = {"x":0, "px":1, "y":2, "py":3}
-            
+            option_dict = {"x": 0, "px": 1, "y": 2, "py": 3}
+
             label = ["D$_{x}$ (m)", "D'$_{x}$", "D$_{y}$ (m)", "D'$_{y}$"]
-            
+
             ylabel = label[option_dict[option]]
-         
-        
-        elif var=="beta" or var=="alpha" or var=="gamma" or var=="mu":
-            option_dict = {"x":0, "y":1}
-            label_dict = {"beta":"$\\beta$", "alpha":"$\\alpha$", 
-                          "gamma":"$\\gamma$", "mu":"$\\mu$"}
-            
+
+        elif var == "beta" or var == "alpha" or var == "gamma" or var == "mu":
+            option_dict = {"x": 0, "y": 1}
+            label_dict = {
+                "beta": "$\\beta$",
+                "alpha": "$\\alpha$",
+                "gamma": "$\\gamma$",
+                "mu": "$\\mu$"
+            }
+
             if option == "x": label_sup = "$_{x}$"
             elif option == "y": label_sup = "$_{y}$"
-            
-            unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)", "mu":""}
-            
+
+            unit = {
+                "beta": " (m)",
+                "alpha": "",
+                "gamma": " (m$^{-1}$)",
+                "mu": ""
+            }
+
             ylabel = label_dict[var] + label_sup + unit[var]
-  
-                
+
         else:
             raise ValueError("Variable name is not found.")
-        
+
         if self.use_local_values is not True:
-            position = np.linspace(0, self.lattice.circumference, int(n_points))
-        else: 
-            position = np.linspace(0,1)
-            
+            position = np.linspace(0, self.lattice.circumference,
+                                   int(n_points))
+        else:
+            position = np.linspace(0, 1)
+
         var_list = var_dict[var](position)[option_dict[option]]
         fig, ax = plt.subplots()
-        ax.plot(position,var_list)
-           
+        ax.plot(position, var_list)
+
         ax.set_xlabel("position (m)")
         ax.set_ylabel(ylabel)
-        
+
         return fig
 
-    
+
 class PhysicalModel:
     """
     Store the lattice physical parameters such as apperture and resistivity.
@@ -452,31 +483,40 @@ class PhysicalModel:
         Return the effective radius of the chamber for resistive wall 
         calculations.
     """
-    def __init__(self, ring, x_right, y_top, shape, rho, x_left=None, 
-                 y_bottom=None, n_points=1e4):
-        
+    def __init__(self,
+                 ring,
+                 x_right,
+                 y_top,
+                 shape,
+                 rho,
+                 x_left=None,
+                 y_bottom=None,
+                 n_points=1e4):
+
         self.n_points = int(n_points)
         self.position = np.linspace(0, ring.L, self.n_points)
-        self.x_right = np.ones_like(self.position)*x_right 
-        self.y_top = np.ones_like(self.position)*y_top
-        
+        self.x_right = np.ones_like(self.position) * x_right
+        self.y_top = np.ones_like(self.position) * y_top
+
         if x_left is None:
-            self.x_left = np.ones_like(self.position)*-1*x_right
+            self.x_left = np.ones_like(self.position) * -1 * x_right
         else:
-            self.x_left = np.ones_like(self.position)*x_left
-        
+            self.x_left = np.ones_like(self.position) * x_left
+
         if y_bottom is None:
-            self.y_bottom = np.ones_like(self.position)*-1*y_top
+            self.y_bottom = np.ones_like(self.position) * -1 * y_top
         else:
-            self.y_bottom = np.ones_like(self.position)*y_bottom
-        
+            self.y_bottom = np.ones_like(self.position) * y_bottom
+
         self.length = self.position[1:] - self.position[:-1]
-        self.center = (self.position[1:] + self.position[:-1])/2
-        self.rho = np.ones_like(self.center)*rho
-        
-        self.shape = np.repeat(np.array([shape]), self.n_points-1)
+        self.center = (self.position[1:] + self.position[:-1]) / 2
+        self.rho = np.ones_like(self.center) * rho
 
-    def resistive_wall_effective_radius(self, optics, x_right=True, 
+        self.shape = np.repeat(np.array([shape]), self.n_points - 1)
+
+    def resistive_wall_effective_radius(self,
+                                        optics,
+                                        x_right=True,
                                         y_top=True):
         """
         Return the effective radius of the chamber for resistive wall 
@@ -523,51 +563,75 @@ class PhysicalModel:
         Instruments and Methods in Physics Research Section A: Accelerators, 
         Spectrometers, Detectors and Associated Equipment 806 (2016): 221-230.         
         """
-        
+
         idx = self.rho != 0
-        
+
         if x_right is True:
-            a0 = (self.x_right[1:] + self.x_right[:-1])/2
+            a0 = (self.x_right[1:] + self.x_right[:-1]) / 2
         else:
-            a0 = np.abs((self.x_left[1:] + self.x_left[:-1])/2)
-            
+            a0 = np.abs((self.x_left[1:] + self.x_left[:-1]) / 2)
+
         if y_top is True:
-            b0 = (self.y_top[1:] + self.y_top[:-1])/2
+            b0 = (self.y_top[1:] + self.y_top[:-1]) / 2
         else:
-            b0 = np.abs((self.y_bottom[1:] + self.y_bottom[:-1])/2)
-            
+            b0 = np.abs((self.y_bottom[1:] + self.y_bottom[:-1]) / 2)
+
         a0 = a0[idx]
         b0 = b0[idx]
-        
+
         beta = optics.beta(self.center[idx])
         length = self.length[idx]
-        
+
         L = length.sum()
-        sigma = 1/self.rho[idx]
-        beta_H_star = 1/L*(length*beta[0,:]).sum()
-        beta_V_star = 1/L*(length*beta[1,:]).sum()
-        sigma_star = 1/L*(length*sigma).sum()
-        
-        a1_H = (((length*beta[0,:]/(np.sqrt(sigma)*(a0)**1)).sum())**(-1)*L*beta_H_star/np.sqrt(sigma_star))**(1/1)
-        a2_H = (((length*beta[0,:]/(np.sqrt(sigma)*(a0)**2)).sum())**(-1)*L*beta_H_star/np.sqrt(sigma_star))**(1/2)
+        sigma = 1 / self.rho[idx]
+        beta_H_star = 1 / L * (length * beta[0, :]).sum()
+        beta_V_star = 1 / L * (length * beta[1, :]).sum()
+        sigma_star = 1 / L * (length * sigma).sum()
 
-        a1_V = (((length*beta[1,:]/(np.sqrt(sigma)*(b0)**1)).sum())**(-1)*L*beta_V_star/np.sqrt(sigma_star))**(1/1)
-        a2_V = (((length*beta[1,:]/(np.sqrt(sigma)*(b0)**2)).sum())**(-1)*L*beta_V_star/np.sqrt(sigma_star))**(1/2)
-        
-        a3_H = (((length*beta[0,:]/(np.sqrt(sigma)*(a0)**3)).sum())**(-1)*L*beta_H_star/np.sqrt(sigma_star))**(1/3)
-        a4_H = (((length*beta[0,:]/(np.sqrt(sigma)*(a0)**4)).sum())**(-1)*L*beta_H_star/np.sqrt(sigma_star))**(1/4)
-        
-        a3_V = (((length*beta[1,:]/(np.sqrt(sigma)*(b0)**3)).sum())**(-1)*L*beta_V_star/np.sqrt(sigma_star))**(1/3)
-        a4_V = (((length*beta[1,:]/(np.sqrt(sigma)*(b0)**4)).sum())**(-1)*L*beta_V_star/np.sqrt(sigma_star))**(1/4)
-        
-        a1_L = min((a1_H,a1_V))
-        a2_L = min((a2_H,a2_V))
-        
-        return (L, 1/sigma_star, beta_H_star, beta_V_star, a1_L, a2_L, a3_H, a4_H, a3_V, a4_V)
-        
-    def change_values(self, start_position, end_position, x_right=None, 
-                      y_top=None, shape=None, rho=None, x_left=None, 
-                      y_bottom=None, sym=True):
+        a1_H = (((length * beta[0, :] /
+                  (np.sqrt(sigma) * (a0)**1)).sum())**(-1) * L * beta_H_star /
+                np.sqrt(sigma_star))**(1 / 1)
+        a2_H = (((length * beta[0, :] /
+                  (np.sqrt(sigma) * (a0)**2)).sum())**(-1) * L * beta_H_star /
+                np.sqrt(sigma_star))**(1 / 2)
+
+        a1_V = (((length * beta[1, :] /
+                  (np.sqrt(sigma) * (b0)**1)).sum())**(-1) * L * beta_V_star /
+                np.sqrt(sigma_star))**(1 / 1)
+        a2_V = (((length * beta[1, :] /
+                  (np.sqrt(sigma) * (b0)**2)).sum())**(-1) * L * beta_V_star /
+                np.sqrt(sigma_star))**(1 / 2)
+
+        a3_H = (((length * beta[0, :] /
+                  (np.sqrt(sigma) * (a0)**3)).sum())**(-1) * L * beta_H_star /
+                np.sqrt(sigma_star))**(1 / 3)
+        a4_H = (((length * beta[0, :] /
+                  (np.sqrt(sigma) * (a0)**4)).sum())**(-1) * L * beta_H_star /
+                np.sqrt(sigma_star))**(1 / 4)
+
+        a3_V = (((length * beta[1, :] /
+                  (np.sqrt(sigma) * (b0)**3)).sum())**(-1) * L * beta_V_star /
+                np.sqrt(sigma_star))**(1 / 3)
+        a4_V = (((length * beta[1, :] /
+                  (np.sqrt(sigma) * (b0)**4)).sum())**(-1) * L * beta_V_star /
+                np.sqrt(sigma_star))**(1 / 4)
+
+        a1_L = min((a1_H, a1_V))
+        a2_L = min((a2_H, a2_V))
+
+        return (L, 1 / sigma_star, beta_H_star, beta_V_star, a1_L, a2_L, a3_H,
+                a4_H, a3_V, a4_V)
+
+    def change_values(self,
+                      start_position,
+                      end_position,
+                      x_right=None,
+                      y_top=None,
+                      shape=None,
+                      rho=None,
+                      x_left=None,
+                      y_bottom=None,
+                      sym=True):
         """
         Change the physical parameters between start_position and end_position.
 
@@ -598,23 +662,33 @@ class PhysicalModel:
         if x_left is not None:
             self.x_left[ind] = x_left
         elif (x_right is not None) and (sym is True):
-            self.x_left[ind] = -1*x_right
+            self.x_left[ind] = -1 * x_right
         if y_bottom is not None:
             self.y_bottom[ind] = y_bottom
         elif (y_top is not None) and (sym is True):
-            self.y_bottom[ind] = -1*y_top
-        
-        ind2 = ((self.position[:-1] > start_position) & 
+            self.y_bottom[ind] = -1 * y_top
+
+        ind2 = ((self.position[:-1] > start_position) &
                 (self.position[1:] < end_position))
         if rho is not None:
             self.rho[ind2] = rho
         if shape is not None:
             self.shape[ind2] = shape
-        
-    def taper(self, start_position, end_position, x_right_start=None, 
-              x_right_end=None, y_top_start=None, y_top_end=None, shape=None, 
-              rho=None, x_left_start=None, x_left_end=None, 
-              y_bottom_start=None, y_bottom_end=None, sym=True):
+
+    def taper(self,
+              start_position,
+              end_position,
+              x_right_start=None,
+              x_right_end=None,
+              y_top_start=None,
+              y_top_end=None,
+              shape=None,
+              rho=None,
+              x_left_start=None,
+              x_left_end=None,
+              y_bottom_start=None,
+              y_bottom_end=None,
+              sym=True):
         """
         Change the physical parameters to have a tapered transition between 
         start_position and end_position.
@@ -648,60 +722,61 @@ class PhysicalModel:
         """
         ind = (self.position > start_position) & (self.position < end_position)
         if (x_right_start is not None) and (x_right_end is not None):
-            self.x_right[ind] = np.linspace(x_right_start, x_right_end, sum(ind))
+            self.x_right[ind] = np.linspace(x_right_start, x_right_end,
+                                            sum(ind))
             if sym is True:
-                self.x_left[ind] = -1*np.linspace(x_right_start, x_right_end, sum(ind))
-                
+                self.x_left[ind] = -1 * np.linspace(x_right_start, x_right_end,
+                                                    sum(ind))
+
         if (y_top_start is not None) and (y_top_end is not None):
             self.y_top[ind] = np.linspace(y_top_start, y_top_end, sum(ind))
             if sym is True:
-                self.y_bottom[ind] = -1*np.linspace(y_top_start, y_top_end, sum(ind))
-            
+                self.y_bottom[ind] = -1 * np.linspace(y_top_start, y_top_end,
+                                                      sum(ind))
+
         if (x_left_start is not None) and (x_left_end is not None):
             self.x_left[ind] = np.linspace(x_left_start, x_left_end, sum(ind))
         if (y_bottom_start is not None) and (y_bottom_end is not None):
-            self.y_bottom[ind] = np.linspace(y_bottom_start, y_bottom_end, sum(ind))
-        
-        ind2 = ((self.position[:-1] > start_position) & 
+            self.y_bottom[ind] = np.linspace(y_bottom_start, y_bottom_end,
+                                             sum(ind))
+
+        ind2 = ((self.position[:-1] > start_position) &
                 (self.position[1:] < end_position))
         if rho is not None:
             self.rho[ind2] = rho
         if shape is not None:
             self.shape[ind2] = shape
-        
+
     def plot_aperture(self):
         """Plot horizontal and vertical apertures."""
         fig, axs = plt.subplots(2)
-        axs[0].plot(self.position,self.x_right*1e3)
-        axs[0].plot(self.position,self.x_left*1e3)
-        axs[0].set(xlabel="Longitudinal position [m]", 
+        axs[0].plot(self.position, self.x_right * 1e3)
+        axs[0].plot(self.position, self.x_left * 1e3)
+        axs[0].set(xlabel="Longitudinal position [m]",
                    ylabel="Horizontal aperture [mm]")
-        axs[0].legend(["Right","Left"])
-        
-        axs[1].plot(self.position,self.y_top*1e3)
-        axs[1].plot(self.position,self.y_bottom*1e3)
-        axs[1].set(xlabel="Longitudinal position [m]", 
+        axs[0].legend(["Right", "Left"])
+
+        axs[1].plot(self.position, self.y_top * 1e3)
+        axs[1].plot(self.position, self.y_bottom * 1e3)
+        axs[1].set(xlabel="Longitudinal position [m]",
                    ylabel="Vertical aperture [mm]")
-        axs[1].legend(["Top","Bottom"])
-        
+        axs[1].legend(["Top", "Bottom"])
+
         return (fig, axs)
-    
+
     def get_aperture(self, s):
         self.xp = interp1d(self.position, self.x_right, kind='linear')
         self.xm = interp1d(self.position, self.x_left, kind='linear')
         self.yp = interp1d(self.position, self.y_top, kind='linear')
         self.ym = interp1d(self.position, self.y_bottom, kind='linear')
-        aperture = np.array([self.xp(s),
-                             self.xm(s),
-                             self.yp(s),
-                             self.ym(s)])
+        aperture = np.array([self.xp(s), self.xm(s), self.yp(s), self.ym(s)])
         return aperture
-    
+
     def plot_resistivity(self):
         """Plot resistivity along the ring."""
         fig, ax = plt.subplots(1)
         ax.plot(self.position[1:], self.rho)
-        ax.set(xlabel="Longitudinal position [m]", 
-                   ylabel="Resistivity [ohm.m]")
-        
-        return (fig, ax)
\ No newline at end of file
+        ax.set(xlabel="Longitudinal position [m]",
+               ylabel="Resistivity [ohm.m]")
+
+        return (fig, ax)
diff --git a/mbtrack2/utilities/read_impedance.py b/mbtrack2/utilities/read_impedance.py
index a0550fee0c07f8e4c9e5b07fe27812afde017b2d..60367a5ba75952bbfa6085c70dd068b8bef99b5f 100644
--- a/mbtrack2/utilities/read_impedance.py
+++ b/mbtrack2/utilities/read_impedance.py
@@ -5,13 +5,15 @@ defined.
 """
 
 import os
-import pandas as pd
-import numpy as np
+import re
 from pathlib import Path
 from tempfile import NamedTemporaryFile
+
+import numpy as np
+import pandas as pd
 from scipy.constants import c
-import re
-from mbtrack2.impedance.wakefield import Impedance, WakeFunction, WakeField
+
+from mbtrack2.impedance.wakefield import Impedance, WakeField, WakeFunction
 
 
 def read_CST(file, component_type='long', divide_by=None, imp=True):
@@ -40,25 +42,31 @@ def read_CST(file, component_type='long', divide_by=None, imp=True):
         Data from file.
     """
     if imp:
-        df = pd.read_csv(file, comment="#", header=None, sep="\t",
+        df = pd.read_csv(file,
+                         comment="#",
+                         header=None,
+                         sep="\t",
                          names=["Frequency", "Real", "Imaginary"])
-        df["Frequency"] = df["Frequency"]*1e9
+        df["Frequency"] = df["Frequency"] * 1e9
         if divide_by is not None:
-            df["Real"] = df["Real"]/divide_by
-            df["Imaginary"] = df["Imaginary"]/divide_by
+            df["Real"] = df["Real"] / divide_by
+            df["Imaginary"] = df["Imaginary"] / divide_by
         if component_type == "long":
             df["Real"] = np.abs(df["Real"])
         df.set_index("Frequency", inplace=True)
         result = Impedance(variable=df.index,
-                           function=df["Real"] + 1j*df["Imaginary"],
+                           function=df["Real"] + 1j * df["Imaginary"],
                            component_type=component_type)
     else:
-        df = pd.read_csv(file, comment="#", header=None, sep="\t",
+        df = pd.read_csv(file,
+                         comment="#",
+                         header=None,
+                         sep="\t",
                          names=["Distance", "Wake"])
-        df["Time"] = df["Distance"]*1e-3/c
-        df["Wake"] = df["Wake"]*1e12
+        df["Time"] = df["Distance"] * 1e-3 / c
+        df["Wake"] = df["Wake"] * 1e12
         if divide_by is not None:
-            df["Wake"] = df["Wake"]/divide_by
+            df["Wake"] = df["Wake"] / divide_by
         df.set_index("Time", inplace=True)
         result = WakeFunction(variable=df.index,
                               function=df["Wake"],
@@ -86,17 +94,23 @@ def read_IW2D(file, file_type='Zlong', output=False):
         Data from file.
     """
     if file_type[0] == "Z":
-        df = pd.read_csv(file, delim_whitespace=True, header=None,
-                         names=["Frequency", "Real", "Imaginary"], skiprows=1)
+        df = pd.read_csv(file,
+                         delim_whitespace=True,
+                         header=None,
+                         names=["Frequency", "Real", "Imaginary"],
+                         skiprows=1)
         df.set_index("Frequency", inplace=True)
         df = df[df["Real"].notna()]
         df = df[df["Imaginary"].notna()]
         result = Impedance(variable=df.index,
-                           function=df["Real"] + 1j*df["Imaginary"],
+                           function=df["Real"] + 1j * df["Imaginary"],
                            component_type=file_type[1:])
     elif file_type[0] == "W":
-        df = pd.read_csv(file, delim_whitespace=True, header=None,
-                         names=["Distance", "Wake"], skiprows=1)
+        df = pd.read_csv(file,
+                         delim_whitespace=True,
+                         header=None,
+                         names=["Distance", "Wake"],
+                         skiprows=1)
         df["Time"] = df["Distance"] / c
         df.set_index("Time", inplace=True)
         if np.any(df.isna()):
@@ -138,8 +152,7 @@ def read_IW2D_folder(folder, suffix, select="WZ", output=False):
         different files.
     """
     if (select == "WZ") or (select == "ZW"):
-        types = {"W": WakeFunction,
-                 "Z": Impedance}
+        types = {"W": WakeFunction, "Z": Impedance}
     elif (select == "W"):
         types = {"W": WakeFunction}
     elif (select == "Z"):
@@ -154,8 +167,10 @@ def read_IW2D_folder(folder, suffix, select="WZ", output=False):
     list_for_wakefield = []
     for key, item in types.items():
         for component in components:
-            name = data_folder / (key + component + suffix)
-            res = read_IW2D(file=name, file_type=key + component, output=output)
+            name = data_folder / (key+component+suffix)
+            res = read_IW2D(file=name,
+                            file_type=key + component,
+                            output=output)
             list_for_wakefield.append(res)
 
     wake = WakeField(list_for_wakefield)
@@ -199,21 +214,24 @@ def read_ABCI(file, azimuthal=False, output=False):
 
     def _read_temp(file, file_type, file2=None):
         if file_type[0] == "Z":
-            df = pd.read_csv(file, delim_whitespace=True,
+            df = pd.read_csv(file,
+                             delim_whitespace=True,
                              names=["Frequency", "Real"])
-            df["Real"] = df["Real"]*1e3
-            df["Frequency"] = df["Frequency"]*1e9
-            df2 = pd.read_csv(file2, delim_whitespace=True,
+            df["Real"] = df["Real"] * 1e3
+            df["Frequency"] = df["Frequency"] * 1e9
+            df2 = pd.read_csv(file2,
+                              delim_whitespace=True,
                               names=["Frequency", "Imaginary"])
-            df2["Imaginary"] = df2["Imaginary"]*1e3
-            df2["Frequency"] = df2["Frequency"]*1e9
+            df2["Imaginary"] = df2["Imaginary"] * 1e3
+            df2["Frequency"] = df2["Frequency"] * 1e9
             df.set_index("Frequency", inplace=True)
             df2.set_index("Frequency", inplace=True)
             result = Impedance(variable=df.index,
-                               function=df["Real"] + 1j*df2["Imaginary"],
+                               function=df["Real"] + 1j * df2["Imaginary"],
                                component_type=file_type[1:])
         elif file_type[0] == "W":
-            df = pd.read_csv(file, delim_whitespace=True,
+            df = pd.read_csv(file,
+                             delim_whitespace=True,
                              names=["Time", "Wake"])
             df["Time"] = df["Time"] / c
             df["Wake"] = df["Wake"] * 1e12
@@ -225,12 +243,20 @@ def read_ABCI(file, azimuthal=False, output=False):
                                   component_type=file_type[1:])
         return result
 
-    abci_dict = {'  TITLE: LONGITUDINAL WAKE POTENTIAL             \n': 'Wlong',
-                 '  TITLE: REAL PART OF LONGITUDINAL IMPEDANCE                      \n': 'Zlong_re',
-                 '  TITLE: IMAGINARY PART OF LONGITUDINAL IMPEDANCE                 \n': 'Zlong_im',
-                 f'  TITLE: {source} WAKE POTENTIAL               \n': 'Wxdip',
-                 f'  TITLE: REAL PART OF {source} IMPEDANCE                        \n': 'Zxdip_re',
-                 f'  TITLE: IMAGINARY PART OF {source} IMPEDANCE                   \n': 'Zxdip_im'}
+    abci_dict = {
+        '  TITLE: LONGITUDINAL WAKE POTENTIAL             \n':
+        'Wlong',
+        '  TITLE: REAL PART OF LONGITUDINAL IMPEDANCE                      \n':
+        'Zlong_re',
+        '  TITLE: IMAGINARY PART OF LONGITUDINAL IMPEDANCE                 \n':
+        'Zlong_im',
+        f'  TITLE: {source} WAKE POTENTIAL               \n':
+        'Wxdip',
+        f'  TITLE: REAL PART OF {source} IMPEDANCE                        \n':
+        'Zxdip_re',
+        f'  TITLE: IMAGINARY PART OF {source} IMPEDANCE                   \n':
+        'Zxdip_im'
+    }
 
     wake_list = []
     start = True
@@ -269,10 +295,12 @@ def read_ABCI(file, azimuthal=False, output=False):
                     tmp.writelines(body)
                     tmp.flush()
                     tmp.close()
-                    if abci_dict[header[0]][1:] == "long" and impedance_type == 0:
+                    if abci_dict[
+                            header[0]][1:] == "long" and impedance_type == 0:
                         comp = _read_temp(tmp.name, abci_dict[header[0]])
                         wake_list.append(comp)
-                    elif abci_dict[header[0]][1:] == "long" and impedance_type == 1:
+                    elif abci_dict[
+                            header[0]][1:] == "long" and impedance_type == 1:
                         pass
                         # Wlongxdip & Wlongydip not implemented yet
                         # comp = _read_temp(tmp.name, abci_dict[header[0]]+'dip')
@@ -283,20 +311,24 @@ def read_ABCI(file, azimuthal=False, output=False):
                         wake_list.append(comp_x)
                         wake_list.append(comp_y)
                     os.unlink(tmp.name)
-                elif (abci_dict[header[0]][0] == "Z") and (abci_dict[header[0]][-2:] == "re"):
+                elif (abci_dict[header[0]][0]
+                      == "Z") and (abci_dict[header[0]][-2:] == "re"):
                     tmp1 = NamedTemporaryFile(delete=False, mode="w+")
                     tmp1.writelines(body)
                     tmp1.flush()
                     tmp1.close()
-                elif (abci_dict[header[0]][0] == "Z") and (abci_dict[header[0]][-2:] == "im"):
+                elif (abci_dict[header[0]][0]
+                      == "Z") and (abci_dict[header[0]][-2:] == "im"):
                     tmp2 = NamedTemporaryFile(delete=False, mode="w+")
                     tmp2.writelines(body)
                     tmp2.flush()
                     tmp2.close()
-                    if abci_dict[header[0]][1:-3] == "long" and impedance_type == 0:
+                    if abci_dict[
+                            header[0]][1:-3] == "long" and impedance_type == 0:
                         comp = _read_temp(tmp1.name, "Zlong", tmp2.name)
                         wake_list.append(comp)
-                    elif abci_dict[header[0]][1:-3] == "long" and impedance_type == 1:
+                    elif abci_dict[
+                            header[0]][1:-3] == "long" and impedance_type == 1:
                         pass
                         # Zlongxdip & Zlongydip not implemented yet
                         # comp = _read_temp(tmp1.name, "Zlongdip", tmp2.name)
@@ -335,12 +367,14 @@ def read_ECHO2D(file, component_type='long'):
         Data from file.
     """
 
-    df = pd.read_csv(file, delim_whitespace=True,
-                     header=None, names=["Distance", "Wake"])
-    df["Time"] = df["Distance"]/100/c
-    df["Wake"] = df["Wake"]*1e12
+    df = pd.read_csv(file,
+                     delim_whitespace=True,
+                     header=None,
+                     names=["Distance", "Wake"])
+    df["Time"] = df["Distance"] / 100 / c
+    df["Wake"] = df["Wake"] * 1e12
     if component_type != 'long':
-        df["Wake"] = df["Wake"]*-1
+        df["Wake"] = df["Wake"] * -1
     df.set_index("Time", inplace=True)
     result = WakeFunction(variable=df.index,
                           function=df["Wake"],
diff --git a/mbtrack2/utilities/spectrum.py b/mbtrack2/utilities/spectrum.py
index df78d64110cf84b337ef9b29bb6d043e835512d5..7f8fb88532443286c8ff6688c8c2740fb7df0a1f 100644
--- a/mbtrack2/utilities/spectrum.py
+++ b/mbtrack2/utilities/spectrum.py
@@ -41,17 +41,19 @@ def spectral_density(frequency, sigma, m=1, k=0, mode="Hermite"):
     """
 
     if mode == "Hermite":
-        return 1/(np.math.factorial(m)*2**m)*(2*np.pi*frequency*sigma)**(2*m)*np.exp(
-            -(2*np.pi*frequency*sigma)**2)
+        return 1 / (np.math.factorial(m) *
+                    2**m) * (2 * np.pi * frequency * sigma)**(
+                        2 * m) * np.exp(-(2 * np.pi * frequency * sigma)**2)
     elif mode == "Chebyshev":
-        tau_l = 4*sigma
-        return (jv(m, 2*np.pi*frequency*tau_l))**2
+        tau_l = 4 * sigma
+        return (jv(m, 2 * np.pi * frequency * tau_l))**2
     elif mode == "Legendre":
-        tau_l = 4*sigma
-        return (spherical_jn(m,  np.abs(2*np.pi*frequency*tau_l)))**2
+        tau_l = 4 * sigma
+        return (spherical_jn(m, np.abs(2 * np.pi * frequency * tau_l)))**2
     elif mode == "Sacherer" or mode == "Sinusoidal":
-        y = 4*2*np.pi*frequency*sigma/np.pi
-        return (2*(m+1)/np.pi*1/np.abs(y**2-(m+1)**2)*np.sqrt(1+(-1)**m*np.cos(np.pi*y)))**2
+        y = 4 * 2 * np.pi * frequency * sigma / np.pi
+        return (2 * (m+1) / np.pi * 1 / np.abs(y**2 - (m + 1)**2) *
+                np.sqrt(1 + (-1)**m * np.cos(np.pi * y)))**2
     else:
         raise NotImplementedError("Not implemanted yet.")
 
@@ -77,7 +79,7 @@ def gaussian_bunch_spectrum(frequency, sigma):
     [1] : Gamelin, A. (2018). Collective effects in a transient microbunching 
     regime and ion cloud mitigation in ThomX. p86, Eq. 4.19
     """
-    return np.exp(-1/2*(2*np.pi*frequency)**2*sigma**2)
+    return np.exp(-1 / 2 * (2 * np.pi * frequency)**2 * sigma**2)
 
 
 def gaussian_bunch(time, sigma):
@@ -96,10 +98,13 @@ def gaussian_bunch(time, sigma):
     bunch_profile : array
         Bunch profile in [s**-1] sampled at points given in time.
     """
-    return np.exp(-1/2*(time**2/sigma**2))/(sigma*np.sqrt(2*np.pi))
+    return np.exp(-1 / 2 * (time**2 / sigma**2)) / (sigma * np.sqrt(2 * np.pi))
 
 
-def beam_spectrum(frequency, M, bunch_spacing, sigma=None,
+def beam_spectrum(frequency,
+                  M,
+                  bunch_spacing,
+                  sigma=None,
                   bunch_spectrum=None):
     """
     Compute the beam spectrum assuming constant spacing between bunches [1].
@@ -131,9 +136,9 @@ def beam_spectrum(frequency, M, bunch_spacing, sigma=None,
     if bunch_spectrum is None:
         bunch_spectrum = gaussian_bunch_spectrum(frequency, sigma)
 
-    beam_spectrum = (bunch_spectrum * np.exp(1j*np.pi*frequency *
-                                             bunch_spacing*(M-1)) *
-                     np.sin(M*np.pi*frequency*bunch_spacing) /
-                     np.sin(np.pi*frequency*bunch_spacing))
+    beam_spectrum = (bunch_spectrum *
+                     np.exp(1j * np.pi * frequency * bunch_spacing * (M-1)) *
+                     np.sin(M * np.pi * frequency * bunch_spacing) /
+                     np.sin(np.pi * frequency * bunch_spacing))
 
     return beam_spectrum