diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index 0a9cad7e174ceb1d11d555b3355f5949d3d51ab0..f28b0df341fc7756b245550a80f8ff043aded3bd 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -32,6 +32,11 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
         needs to be specified. The default is None.
     x_var : str, optional
         Variable to be plotted on the horizontal axis. The default is "time".
+        
+    Return
+    ------
+    fig : Figure
+        Figure object with the plot on it.
 
     """
     
@@ -43,9 +48,10 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
         for i in range (len(path["time"])):
             total_current.append(np.sum(path["current"][:,i])*1e3)
             
-        plt.plot(path["time"],total_current)
-        plt.xlabel("Number of turns")
-        plt.ylabel("total current (mA)")
+        fig, ax = plt.subplots()
+        ax.plot(path["time"],total_current)
+        ax.set_xlabel("Number of turns")
+        ax.set_ylabel("total current (mA)")
         
     elif dataset == "emit":
         option_dict = {"x":0, "y":1, "s":2} #input option
@@ -60,75 +66,50 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
             for i in range (len(path["time"])):
                 mean_emit.append(np.mean(path["emit"][axis,:,i])*scale[axis])
             
-            plt.plot(path["time"],mean_emit)
-            label_sup = "avg. "
+            fig, ax = plt.subplots()
+            ax.plot(path["time"],mean_emit)
             
         elif stat_var == "std":
             std_emit = []
             for i in range (len(path["time"])):
                 std_emit.append(np.std(path["emit"][axis,:,i])*scale[axis])
-                
-            plt.plot(path["time"],std_emit)
-            label_sup = "std. "
+              
+            fig, ax = plt.subplots() 
+            ax.plot(path["time"],std_emit)
             
-        plt.xlabel("Number of turns")
-        plt.ylabel(label_sup + label[axis])
+        ax.set_xlabel("Number of turns")
+        ax.set_ylabel(stat_var+" " + label[axis])
         
-    elif dataset == "mean":
+    elif dataset == "mean" or dataset == "std":
         option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
         axis = option_dict[option]
         scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
         label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                       "$\\tau$ (ps)", "$\\delta$"]
         
+        fig, ax = plt.subplots()
         if stat_var == "mean":
-            mean_mean = []
+            mean_list = []
             for i in range (len(path["time"])):
-                mean_mean.append(np.mean(path["mean"][axis,:,i]*scale[axis]))
+                mean_list.append(np.mean(path[dataset][axis,:,i]*scale[axis]))
                 
-            plt.plot(path["time"],mean_mean)
-            label_sup = "avg. "
+            ax.plot(path["time"],mean_list)
+            label_sup = {"mean":"", "std":"std of "} # input stat_var
             
         elif stat_var == "std":
-            std_mean = []
+            std_list = []
             for i in range (len(path["time"])):
-                std_mean.append(np.std(path["mean"][axis,:,i]*scale[axis]))
+                std_list.append(np.std(path[dataset][axis,:,i]*scale[axis]))
             
-            plt.plot(path["time"],std_mean)
-            label_sup = "std. of avg. "
-            
-        plt.xlabel("Number of turns")
-        plt.ylabel(label_sup + label[axis])
-        
-    elif dataset == "std":
-        option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-        axis = option_dict[option]
-        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-        label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                      "$\\tau$ (ps)", "$\\delta$"]
-        
-        if stat_var == "mean":
-            mean_std = []
-            for i in range (len(path["time"])):
-                mean_std.append(np.mean(path["std"][axis,:,i]*scale[axis]))
-                
-            plt.plot(path["time"],mean_std)
-            label_sup = "std. "
+            ax.plot(path["time"],std_list)
+            label_sup = {"mean":"", "std":"std of "} #input stat_var
             
-        elif stat_var == "std":
-            std_std = []
-            for i in range (len(path["time"])):
-                std_std.append(np.std(path["std"][axis,:,i]*scale[axis]))
-            
-            plt.plot(path["time"],std_std)
-            label_sup = "std. of std. "
-            
-        plt.xlabel("Number of turns")
-        plt.ylabel(label_sup + label[axis])
+        ax.set_xlabel("Number of turns")
+        ax.set_ylabel(label_sup[stat_var] + dataset +" "+ label[axis])
     
     file.close()
-                        
-                  
+    return fig    
+              
 def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
     """
     Plot data recorded from a BunchMonitor.
@@ -149,6 +130,11 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
             for "mean" and "std", option = {"x","xp","y","yp","tau","delta"}    
     x_var : str, optional
         Variable to be plotted on the horizontal axis. The default is "time".
+        
+    Return
+    ------
+    fig : Figure
+        Figure object with the plot on it.
 
     """
     
@@ -161,43 +147,32 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
         label = "current (mA)"
         
     elif dataset == "emit":
-        # Specifying the axis
-        # horizontal axis (x)   -> 0 
-        # vertical axis (y)     -> 1 
-        # longitudinal axis (s) -> 2 
-                         
-        emit_axis = {"x":0, "y":1, "s":2}
+        option_dict = {"x":0, "y":1, "s":2}
                          
-        y_var = file[group][dataset][emit_axis[option]]*1e9
+        y_var = file[group][dataset][option_dict[option]]*1e9
         
         if option == "x": label = "hor. emittance (nm.rad)"
         elif option == "y": label = "ver. emittance (nm.rad)"
         elif option == "s": label = "long. emittance (nm.rad)"
         
         
-    elif dataset == "mean" or "std":
-        # Specify the variable
-        # x     -> 0 
-        # xp    -> 1 
-        # y     -> 2 
-        # yp    -> 3 
-        # tau   -> 4 
-        # delta -> 5 
-                           
-        var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+    elif dataset == "mean" or "std":                        
+        option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
         scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
         label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                       "$\\tau$ (ps)", "$\\delta$"]
         
-        y_var = file[group][dataset][var_dict[option]]*scale[var_dict[option]]
-        label = label_list[var_dict[option]]
+        y_var = file[group][dataset][option_dict[option]]\
+                *scale[option_dict[option]]
+        label = label_list[option_dict[option]]
         
-            
-    plt.plot(file[group]["time"][:],y_var)
-    plt.xlabel("number of turns")
-    plt.ylabel(label)
+    fig, ax = plt.subplots()        
+    ax.plot(file[group]["time"][:],y_var)
+    ax.set_xlabel("number of turns")
+    ax.set_ylabel(label)
     
     file.close()
+    return fig
             
 def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, 
                         only_alive=True):
@@ -220,6 +195,11 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
     only_alive : bool, optional
         When only_alive is True, only alive particles are plotted and dead 
         particles will be discarded.
+        
+    Return
+    ------
+    fig : Figure
+        Figure object with the plot on it.
     """
     
     file = hp.File(filename, "r")
@@ -227,15 +207,7 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
     group = "PhaseSpaceData_{0}".format(bunch_number)
     dataset = "particles"
 
-    # Specify the parameter
-    #     x     -> 0 
-    #     xp    -> 1 
-    #     y     -> 2 
-    #     yp    -> 3 
-    #     tau   -> 4 
-    #     delta -> 5  
-             
-    var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+    option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
     scale = [1e3,1e3,1e3,1e3,1e12,1]
     label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
              "$\\delta$"]
@@ -251,24 +223,20 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
 
     if only_alive is False:
         # format : sns.jointplot(x_axis, yaxis, kind)
-        x_axis = path[:,var_dict[x_var],turn_index[0][0]]
-        y_axis = path[:,var_dict[y_var],turn_index[0][0]]
+        x_axis = path[:,option_dict[x_var],turn_index[0][0]]
+        y_axis = path[:,option_dict[y_var],turn_index[0][0]]
 
     elif only_alive is True:
         alive_index = np.where(file[group]["alive"][:,turn_index])
 
-        x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]]
-        y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]]            
+        x_axis = path[alive_index[0],option_dict[x_var],turn_index[0][0]]
+        y_axis = path[alive_index[0],option_dict[y_var],turn_index[0][0]]    
         
-    sns.jointplot(x_axis*scale[var_dict[x_var]], 
-                  y_axis*scale[var_dict[y_var]], kind="kde")
+    fig = sns.jointplot(x_axis*scale[option_dict[x_var]], 
+                        y_axis*scale[option_dict[y_var]], kind="kde")
     
-    plt.xlabel(label[var_dict[x_var]])
-    plt.ylabel(label[var_dict[y_var]])
+    plt.xlabel(label[option_dict[x_var]])
+    plt.ylabel(label[option_dict[y_var]])
             
     file.close()
-
-    
-    
-     
-    
+    return fig
diff --git a/tracking/optics.py b/tracking/optics.py
index be326c44431ae43988080dbc4571d4ce922de41b..da6ae20aec203e9828cf92d9066fda9f79345357 100644
--- a/tracking/optics.py
+++ b/tracking/optics.py
@@ -55,6 +55,8 @@ class Optics:
         Return gamma functions at specific locations given by position.
     dispersion(position)
         Return dispersion functions at specific locations given by position.
+    plot(self, var, option, n_points=1000)
+        Plot optical variables.
     """
     
     def __init__(self, lattice_file=None, local_beta=None, local_alpha=None, 
@@ -245,9 +247,9 @@ class Optics:
                           self.dispY(position), self.disppY(position)]
             return np.array(dispersion)
         
-    def plot_optics(self, var, option, n_points=1000):
+    def plot(self, var, option, n_points=1000):
         """
-        Plotting optical variables.
+        Plot optical variables.
     
         Parameters
         ----------
@@ -271,7 +273,7 @@ class Optics:
             
             label = ["D$_{x}$ (m)", "D'$_{x}$", "D$_{y}$ (m)", "D'$_{y}$"]
             
-            plt.ylabel(label[option_dict[option]])
+            ylabel = label[option_dict[option]]
          
         
         elif var=="beta" or var=="alpha" or var=="gamma":
@@ -284,17 +286,25 @@ class Optics:
             
             unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)"}
             
-            plt.ylabel(label_dict[var] + label_sup + unit[var])
+            ylabel = label_dict[var] + label_sup + unit[var]
   
                 
         else:
             raise ValueError("Variable name is not found.")
         
-        position = np.linspace(0, self.lattice.circumference, int(n_points))
+        if self.use_local_values is not True:
+            position = np.linspace(0, self.lattice.circumference, int(n_points))
+        else: 
+            position = np.linspace(0,1)
+            
         var_list = var_dict[var](position)[option_dict[option]]
-        plt.plot(position,var_list)
+        fig, ax = plt.subplots()
+        ax.plot(position,var_list)
            
-        plt.xlabel("position (m)")
+        ax.set_xlabel("position (m)")
+        ax.set_ylabel(ylabel)
+        
+        return fig
 
     
 class PhyisicalModel:
diff --git a/tracking/particles.py b/tracking/particles.py
index cb4c0468306f63b81fb9d45fc56f0127713b6f83..c6a802e7038fc4f4fb71cb806cad8777c8d1367d 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -336,26 +336,35 @@ class Bunch:
         plot_type : str {"j" , "sc"} 
             Type of the plot. The defualt value is "j" for a joint plot.
             Can be modified to "sc" for a scatter plot.
+            
+        Return
+        ------
+        fig : Figure
+            Figure object with the plot on it.
         """
         
         label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", 
                       "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"}
         scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1}
         
+        
         if plot_type == "sc":
-            plt.scatter(self.particles[x_var]*scale[x_var],
-                        self.particles[y_var]*scale[y_var])
-            plt.xlabel(label_dict[x_var])
-            plt.ylabel(label_dict[y_var])
+            fig, ax = plt.subplots()
+            ax.scatter(self.particles[x_var]*scale[x_var],
+                       self.particles[y_var]*scale[y_var])
+            ax.set_xlabel(label_dict[x_var])
+            ax.set_ylabel(label_dict[y_var])
         
         elif plot_type == "j": 
-            sns.jointplot(self.particles[x_var]*scale[x_var],
-                          self.particles[y_var]*scale[y_var],kind="kde")
+            fig = sns.jointplot(self.particles[x_var]*scale[x_var],
+                                self.particles[y_var]*scale[y_var],kind="kde")
             plt.xlabel(label_dict[x_var])
             plt.ylabel(label_dict[y_var])
             
         else: 
             raise ValueError("Plot type not recognised.")
+            
+        return fig
         
 class Beam:
     """
@@ -630,6 +639,11 @@ class Beam:
                 option = {"x","xp","y","yp","tau","delta"}.
             For "bunch_emit", option = {"x","y","s"}.
             The default is None.
+            
+        Return
+        ------
+        fig : Figure
+            Figure object with the plot on it.
         """
         
         var_dict = {"bunch_current":self.bunch_current,
@@ -639,6 +653,8 @@ class Beam:
                     "bunch_std":self.bunch_std,
                     "bunch_emit":self.bunch_emit}
         
+        fig, ax= plt.subplots()
+        
         if var == "bunch_mean" or var == "bunch_std":
             value_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
             scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
@@ -654,13 +670,13 @@ class Beam:
             where_is_nan = np.isnan(y_axis)
             y_axis[where_is_nan] = 0
             
-            plt.plot(np.arange(len(self.filling_pattern)),
+            ax.plot(np.arange(len(self.filling_pattern)),
                       y_axis*scale[value_dict[option]])
-            plt.xlabel('bunch number')
+            ax.set_xlabel('bunch number')
             if var == "bunch_mean":
-                plt.ylabel(label_mean[value_dict[option]])
+                ax.set_ylabel(label_mean[value_dict[option]])
             else: 
-                plt.ylabel(label_std[value_dict[option]])
+                ax.set_ylabel(label_std[value_dict[option]])
             
         elif var == "bunch_emit":
             value_dict = {"x":0, "y":1, "s":2}
@@ -672,33 +688,34 @@ class Beam:
             where_is_nan = np.isnan(y_axis)
             y_axis[where_is_nan] = 0
             
-            plt.plot(np.arange(len(self.filling_pattern)), 
+            ax.plot(np.arange(len(self.filling_pattern)), 
                      y_axis*scale[value_dict[option]])
             
             if option == "x": label_y = "hor. emittance (nm.rad)"
             elif option == "y": label_y = "ver. emittance (nm.rad)"
             elif option == "s": label_y =  "long. emittance (fm.rad)"
             
-            plt.xlabel('bunch number')
-            plt.ylabel(label_y)
+            ax.set_xlabel('bunch number')
+            ax.set_ylabel(label_y)
                 
         elif var=="bunch_current" or var=="bunch_charge" or var=="bunch_particle":
             scale = {"bunch_current":1e3, "bunch_charge":1e9, 
                      "bunch_particle":1}
             
-            plt.plot(np.arange(len(self.filling_pattern)), var_dict[var]*
+            ax.plot(np.arange(len(self.filling_pattern)), var_dict[var]*
                      scale[var]) 
-            plt.xlabel('bunch number')
+            ax.set_xlabel('bunch number')
             
             if var == "bunch_current": label_y = "bunch current (mA)"
             elif var == "bunch_charge": label_y = "bunch chagre (nC)"
             else: label_y = "number of particles"
 
-            plt.ylabel(label_y)             
+            ax.set_ylabel(label_y)             
     
         elif var == "current" or var=="charge" or var=="particle_number":
-            print("'{0}'is a total value and cannot be plotted".format(var))
+            raise ValueError("'{0}'is a total value and cannot be plotted."
+                             .format(var))
        
-        
+        return fig