diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index a7ddf85ea4b324783bbf2aacad89dfa345398735..9dfd6910fb1baa41950d9dea10eaf7b4b3063740 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -366,7 +366,7 @@ class Bunch:
         ax = fig.gca()
         ax.plot(center, profile)
         
-    def plot_phasespace(self, x_var="tau", y_var="delta", plot_type="j"):
+    def plot_phasespace(self, x_var="tau", y_var="delta", kind="scatter"):
         """
         Plot phase space.
         
@@ -376,9 +376,9 @@ class Bunch:
             Dimension to plot on horizontal axis.
         y_var : str 
             Dimension to plot on vertical axis.
-        plot_type : str {"j" , "sc"} 
-            Type of the plot. The defualt value is "j" for a joint plot.
-            Can be modified to "sc" for a scatter plot.
+        kind : str, {"scatter", "kde", "hist", "hex", "reg", "resid"}
+            Kind of plot to draw. See seeborn documentation.
+            The defualt value is "scatter".
             
         Return
         ------
@@ -389,23 +389,12 @@ class Bunch:
         label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", 
                       "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"}
         scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1}
-        
-        
-        if plot_type == "sc":
-            fig, ax = plt.subplots()
-            ax.scatter(self.particles[x_var]*scale[x_var],
-                       self.particles[y_var]*scale[y_var])
-            ax.set_xlabel(label_dict[x_var])
-            ax.set_ylabel(label_dict[y_var])
-        
-        elif plot_type == "j": 
-            fig = sns.jointplot(self.particles[x_var]*scale[x_var],
-                                self.particles[y_var]*scale[y_var],kind="kde")
-            plt.xlabel(label_dict[x_var])
-            plt.ylabel(label_dict[y_var])
-            
-        else: 
-            raise ValueError("Plot type not recognised.")
+
+        fig = sns.jointplot(x=self.particles[x_var]*scale[x_var],
+                            y=self.particles[y_var]*scale[y_var],
+                            kind=kind)
+        plt.xlabel(label_dict[x_var])
+        plt.ylabel(label_dict[y_var])
             
         return fig