diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py index a7ddf85ea4b324783bbf2aacad89dfa345398735..9dfd6910fb1baa41950d9dea10eaf7b4b3063740 100644 --- a/mbtrack2/tracking/particles.py +++ b/mbtrack2/tracking/particles.py @@ -366,7 +366,7 @@ class Bunch: ax = fig.gca() ax.plot(center, profile) - def plot_phasespace(self, x_var="tau", y_var="delta", plot_type="j"): + def plot_phasespace(self, x_var="tau", y_var="delta", kind="scatter"): """ Plot phase space. @@ -376,9 +376,9 @@ class Bunch: Dimension to plot on horizontal axis. y_var : str Dimension to plot on vertical axis. - plot_type : str {"j" , "sc"} - Type of the plot. The defualt value is "j" for a joint plot. - Can be modified to "sc" for a scatter plot. + kind : str, {"scatter", "kde", "hist", "hex", "reg", "resid"} + Kind of plot to draw. See seeborn documentation. + The defualt value is "scatter". Return ------ @@ -389,23 +389,12 @@ class Bunch: label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"} scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1} - - - if plot_type == "sc": - fig, ax = plt.subplots() - ax.scatter(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var]) - ax.set_xlabel(label_dict[x_var]) - ax.set_ylabel(label_dict[y_var]) - - elif plot_type == "j": - fig = sns.jointplot(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var],kind="kde") - plt.xlabel(label_dict[x_var]) - plt.ylabel(label_dict[y_var]) - - else: - raise ValueError("Plot type not recognised.") + + fig = sns.jointplot(x=self.particles[x_var]*scale[x_var], + y=self.particles[y_var]*scale[y_var], + kind=kind) + plt.xlabel(label_dict[x_var]) + plt.ylabel(label_dict[y_var]) return fig