diff --git a/tracking/monitors/__init__.py b/tracking/monitors/__init__.py
index adc490ba661610646c8ad5831362e407daf7984f..3ff31947a254f22aecfdd614e555637ef5b6462a 100644
--- a/tracking/monitors/__init__.py
+++ b/tracking/monitors/__init__.py
@@ -18,4 +18,5 @@ from mbtrack2.tracking.monitors.plotting import (plot_bunchdata,
                                                  plot_beamdata,
                                                  plot_wakedata,
                                                  plot_tunedata,
-                                                 plot_cavitydata)
\ No newline at end of file
+                                                 plot_cavitydata,
+                                                 streak_beamdata)
\ No newline at end of file
diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index 780dc4875b6fb2074f1132150677d197f016f7a4..a2c2f3b2273eb49601c3a212ad947b4338ec3929 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -3,8 +3,8 @@
 Module for plotting the data recorded by the monitor module during the 
 tracking.
 
-@author: Watanyu Foosang
-@Date: 10/04/2020
+@author: Watanyu Foosang, Alexis Gamelin
+@Date: 20/07/2021
 """
 
 import numpy as np
@@ -15,15 +15,15 @@ import h5py as hp
 import random
 from scipy.fft import rfftfreq
 
-def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean", 
-                  x_var="time", turn=None, plot_type="normal", cm_lim=None):
+def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean", 
+                  x_var="time", turn=None, legend=None):
     """
-    Plot data recorded by BeamMonitor.
+    Plot 2D data recorded by BeamMonitor.
 
     Parameters
     ----------
-    filename : str
-        Name of the HDF5 file that contains the data.
+    filenames : str or list of str
+        Names of the HDF5 files to be plotted.
     dataset : {"current","emit","mean","std"}
         HDF5 file's dataset to be plotted. The default is "mean".
     dimension : str
@@ -35,18 +35,14 @@ def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean",
     stat_var : {"mean", "std"}
         Statistical value of the dimension.
     x_var : {"time", "index"}
-        Choice of the horizontal axis for "normal" plot_type.
-        "time" corresponds to turn number.
-        "index" corresponds to bunch index.
+        Choice of the horizontal axis:
+            "time" corresponds to turn number.
+            "index" corresponds to bunch index.
     turn : int or float, optional
-        Choice of the turn to plot for the "normal" plot_type with "index".
+        Choice of the turn to plot when x_var = "index".
         If None, the last turn is plotted.
-    plot_type : {"normal","streak"}, optional
-        Choice of the type of plot. 
-        "normal" is for the standard line 2D plot.
-        "streak" is for the 3D like image.
-    cm_lim : list [vmin, vmax], optional
-        Colormap scale for the "streak" plot.
+    legend : list of str, optional
+        Legend to add for each file.
 
     Return
     ------
@@ -54,17 +50,24 @@ def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean",
         Figure object with the plot on it.
 
     """
-    file = hp.File(filename, "r")
-    data = file["Beam"]
-    time = np.array(data["time"])
     
-    if plot_type == "normal":
+    if isinstance(filenames, str):
+        filenames = [filenames]
     
+    fig, ax = plt.subplots()
+    
+    for filename in filenames:
+        file = hp.File(filename, "r")
+        data = file["Beam"]
+        time = np.array(data["time"])
+            
         if x_var == "time":
             x = time
             x_label = "Number of turns"
+            bunch_index = data["current"][:,0] != 0
+            
             if dataset == "current":
-                y = np.nansum(data["current"][:],0)*1e3
+                y = np.nansum(data[dataset][bunch_index,:],0)*1e3
                 y_label = "Total current (mA)"
             elif dataset == "emit":
                 dimension_dict = {"x":0, "y":1, "s":2}
@@ -73,9 +76,9 @@ def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean",
                          "$\\epsilon_{y}$ (m.rad)",
                          "$\\epsilon_{s}$ (m.rad)"]
                 if stat_var == "mean":
-                    y = np.nanmean(data["emit"][axis,:],0)
+                    y = np.nanmean(data[dataset][axis,bunch_index,:],0)
                 elif stat_var == "std":
-                    y = np.nanstd(data["emit"][axis,:],0)
+                    y = np.nanstd(data[dataset][axis,bunch_index,:],0)
                 y_label = stat_var + " " + label[axis]
             elif dataset == "mean" or dataset == "std":
                 dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
@@ -85,9 +88,9 @@ def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean",
                 label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
                          "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
                 if stat_var == "mean":   
-                    y = np.nanmean(data[dataset][axis,:],0)*scale[axis]
+                    y = np.nanmean(data[dataset][axis,bunch_index,:],0)*scale[axis]
                 elif stat_var == "std":      
-                    y = np.nanstd(data[dataset][axis,:],0)*scale[axis]
+                    y = np.nanstd(data[dataset][axis,bunch_index,:],0)*scale[axis]
                 label_sup = {"mean":"mean of ", "std":"std of "}
                 y_label = label_sup[stat_var] + dataset + " " + label[axis]
                 
@@ -122,60 +125,109 @@ def plot_beamdata(filename, dataset="mean", dimension="tau", stat_var="mean",
         else:
             raise ValueError("x_var should be time or index")
         
-        fig, ax = plt.subplots()
         ax.plot(x, y)
         ax.set_xlabel(x_label)
         ax.set_ylabel(y_label)
+        if legend is not None:
+            plt.legend(legend)
+            
+        file.close()
         
-    elif plot_type == "streak":
-        h = len(data["mean"][0,:,0])
-        x = np.arange(h)
-        x_label = "Bunch index"
-        y = time
-        y_label = "Number of turns"
-        if dataset == "current":
-            z = np.array(data["current"]*1e3).T
-            z_label = "Bunch current (mA)"
-        elif dataset == "emit":
-            dimension_dict = {"x":0, "y":1, "s":2}
-            axis = dimension_dict[dimension]
-            label = ["$\\epsilon_{x}$ (m.rad)",
-                     "$\\epsilon_{y}$ (m.rad)",
-                     "$\\epsilon_{s}$ (m.rad)"]
-            z = np.array(data["emit"][axis,:,:]).T
-            z_label = label[axis]
-        else:
-            dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
-                                  "delta":5}
-            axis = dimension_dict[dimension]
-            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-            label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
-                         "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
-            z = np.array(data[dataset][axis,:,:]).T*scale[axis]
-            z_label = label[axis]
+    return fig
+            
+def streak_beamdata(filename, dataset="mean", dimension="tau", cm_lim=None):
+    """
+    Plot 3D data recorded by BeamMonitor.
+
+    Parameters
+    ----------
+    filename : str
+        Name of the HDF5 file that contains the data.
+    dataset : {"current","emit","mean","std"}
+        HDF5 file's dataset to be plotted. The default is "mean".
+    dimension : str
+         The dimension of the dataset to plot:
+            for "emit", dimension = {"x","y","s"},
+            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
+            not used if "current".
+        The default is "tau".
+    cm_lim : list [vmin, vmax], optional
+        Colormap scale for the "streak" plot.
+
+    Return
+    ------
+    fig : Figure
+        Figure object with the plot on it.
+
+    """
+    
+    file = hp.File(filename, "r")
+    data = file["Beam"]
+    time = np.array(data["time"])
         
-        fig, ax = plt.subplots()
-        ax.set_xlabel(x_label)
-        ax.set_ylabel(y_label)
-        cmap = mpl.cm.cool
-        c = ax.imshow(z, cmap=cmap, origin='lower' , aspect='auto',
-               extent=[x.min(), x.max(), y.min(), y.max()])
-        if cm_lim is not None:
-            c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
-        cbar = fig.colorbar(c, ax=ax)
-        cbar.set_label(z_label) 
+    h = len(data["mean"][0,:,0])
+    x = np.arange(h)
+    x_label = "Bunch index"
+    y = time
+    y_label = "Number of turns"
+    if dataset == "current":
+        z = (np.array(data["current"])*1e3).T
+        z_label = "Bunch current (mA)"
+        title = z_label
+    elif dataset == "emit":
+        dimension_dict = {"x":0, "y":1, "s":2}
+        axis = dimension_dict[dimension]
+        label = ["$\\epsilon_{x}$ (m.rad)",
+                 "$\\epsilon_{y}$ (m.rad)",
+                 "$\\epsilon_{s}$ (m.rad)"]
+        z = np.array(data["emit"][axis,:,:]).T
+        z_label = label[axis]
+        title = z_label
+    else:
+        dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
+                              "delta":5}
+        axis = dimension_dict[dimension]
+        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
+        label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
+                     "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
+        z = np.array(data[dataset][axis,:,:]).T*scale[axis]
+        z_label = label[axis]
+        if dataset == "mean":
+            title = label[axis] + " CM"
+        elif dataset == "std":
+            title = label[axis] + " RMS"
+            
+    fig, ax = plt.subplots()
+    ax.set_xlabel(x_label)
+    ax.set_ylabel(y_label)
+    ax.set_title(title)
+    
+    if dataset == "mean":
+        cmap = mpl.cm.coolwarm # diverging
+    else:
+        cmap = mpl.cm.inferno # sequential
+    
+    c = ax.imshow(z, cmap=cmap, origin='lower' , aspect='auto',
+            extent=[x.min(), x.max(), y.min(), y.max()])
+    if cm_lim is not None:
+        c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
+    cbar = fig.colorbar(c, ax=ax)
+    cbar.set_label(z_label)
+    
+    file.close()
         
     return fig
               
-def plot_bunchdata(filename, bunch_number, dataset, dimension="x", x_var="time"):
+def plot_bunchdata(filenames, bunch_number, dataset, dimension="x", 
+                   legend=None):
     """
     Plot data recorded by BunchMonitor.
     
     Parameters
     ----------
-    filename : str 
-        Name of the HDF5 file that contains the data.
-    bunch_number : int
+    filenames : str or list of str
+        Names of the HDF5 files to be plotted.
+    bunch_number : int or list of int
         Bunch to plot. This has to be identical to 'bunch_number' parameter in 
         'BunchMonitor' object.
     dataset : {"current", "emit", "mean", "std", "cs_invariant"}
@@ -186,8 +238,8 @@ def plot_bunchdata(filename, bunch_number, dataset, dimension="x", x_var="time")
             for "emit", dimension = {"x","y","s"},
             for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"},
             for "action", dimension = {"x","y"}.
-    x_var : {"time", "current"}, optional
-        Variable to be plotted on the horizontal axis. The default is "time".
+    legend : list of str, optional
+        Legend to add for each file.
         
     Return
     ------
@@ -196,60 +248,70 @@ def plot_bunchdata(filename, bunch_number, dataset, dimension="x", x_var="time")
 
     """
     
-    file = hp.File(filename, "r")
-    
-    group = "BunchData_{0}".format(bunch_number)  # Data group of the HDF5 file
-    
-    if dataset == "current":
-        y_var = file[group][dataset][:]*1e3
-        label = "current (mA)"
-        
-    elif dataset == "emit":
-        dimension_dict = {"x":0, "y":1, "s":2}
-                         
-        y_var = file[group][dataset][dimension_dict[dimension]]*1e9
-        
-        if dimension == "x": label = "hor. emittance (nm.rad)"
-        elif dimension == "y": label = "ver. emittance (nm.rad)"
-        elif dimension == "s": label = "long. emittance (nm.rad)"
+    if isinstance(filenames, str):
+        filenames = [filenames]
         
+    if isinstance(bunch_number, int):
+        ll = []
+        for i in range(len(filenames)):
+            ll.append(bunch_number)
+        bunch_number = ll
         
-    elif dataset == "mean" or dataset == "std":                        
-        dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
-        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]        
-        axis_index = dimension_dict[dimension]
+    fig, ax = plt.subplots()
+    
+    for i, filename in enumerate(filenames):
+        file = hp.File(filename, "r")
+        group = "BunchData_{0}".format(bunch_number[i])  # Data group of the HDF5 file
         
-        y_var = file[group][dataset][axis_index]*scale[axis_index]
-        if dataset == "mean":
-            label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
-                          "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
-        else:
-            label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
-                          "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
-                          "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
-        
-        label = label_list[axis_index]
-        
-    elif dataset == "cs_invariant":
-        dimension_dict = {"x":0, "y":1}
-        axis_index = dimension_dict[dimension]
-        y_var = file[group][dataset][axis_index]
-        label_list = ['$J_x$ (m)', '$J_y$ (m)']
-        label = label_list[axis_index]
-        
-    if x_var == "current":
-        x_axis = file[group]["current"][:] * 1e3
-        xlabel = "current (mA)"
-    elif x_var == "time":
+        if dataset == "current":
+            y_var = file[group][dataset][:]*1e3
+            label = "current (mA)"
+            
+        elif dataset == "emit":
+            dimension_dict = {"x":0, "y":1, "s":2}
+                             
+            y_var = file[group][dataset][dimension_dict[dimension]]*1e9
+            
+            if dimension == "x": label = "hor. emittance (nm.rad)"
+            elif dimension == "y": label = "ver. emittance (nm.rad)"
+            elif dimension == "s": label = "long. emittance (nm.rad)"
+            
+            
+        elif dataset == "mean" or dataset == "std":                        
+            dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
+            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]        
+            axis_index = dimension_dict[dimension]
+            
+            y_var = file[group][dataset][axis_index]*scale[axis_index]
+            if dataset == "mean":
+                label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
+                              "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
+            else:
+                label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
+                              "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
+                              "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
+            
+            label = label_list[axis_index]
+            
+        elif dataset == "cs_invariant":
+            dimension_dict = {"x":0, "y":1}
+            axis_index = dimension_dict[dimension]
+            y_var = file[group][dataset][axis_index]
+            label_list = ['$J_x$ (m)', '$J_y$ (m)']
+            label = label_list[axis_index]
+
         x_axis = file[group]["time"][:]
-        xlabel = "number of turns"
+        xlabel = "Number of turns"
+            
+        fig, ax = plt.subplots()        
+        ax.plot(x_axis, y_var)
+        ax.set_xlabel(xlabel)
+        ax.set_ylabel(label)
+        if legend is not None:
+            plt.legend(legend)
+            
+        file.close()
         
-    fig, ax = plt.subplots()        
-    ax.plot(x_axis,y_var)
-    ax.set_xlabel(xlabel)
-    ax.set_ylabel(label)
-    
-    file.close()
     return fig
             
 def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,