From d756f5a2d6b6d1b6daed5faf4bbb21587d9aeffa Mon Sep 17 00:00:00 2001 From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr> Date: Mon, 19 Jun 2023 16:33:04 +0200 Subject: [PATCH] [Fix] Bunch.plot_phasespace Now works correctly for up to date version of seaborn. --- mbtrack2/tracking/particles.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py index a7ddf85..9dfd691 100644 --- a/mbtrack2/tracking/particles.py +++ b/mbtrack2/tracking/particles.py @@ -366,7 +366,7 @@ class Bunch: ax = fig.gca() ax.plot(center, profile) - def plot_phasespace(self, x_var="tau", y_var="delta", plot_type="j"): + def plot_phasespace(self, x_var="tau", y_var="delta", kind="scatter"): """ Plot phase space. @@ -376,9 +376,9 @@ class Bunch: Dimension to plot on horizontal axis. y_var : str Dimension to plot on vertical axis. - plot_type : str {"j" , "sc"} - Type of the plot. The defualt value is "j" for a joint plot. - Can be modified to "sc" for a scatter plot. + kind : str, {"scatter", "kde", "hist", "hex", "reg", "resid"} + Kind of plot to draw. See seeborn documentation. + The defualt value is "scatter". Return ------ @@ -389,23 +389,12 @@ class Bunch: label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"} scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1} - - - if plot_type == "sc": - fig, ax = plt.subplots() - ax.scatter(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var]) - ax.set_xlabel(label_dict[x_var]) - ax.set_ylabel(label_dict[y_var]) - - elif plot_type == "j": - fig = sns.jointplot(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var],kind="kde") - plt.xlabel(label_dict[x_var]) - plt.ylabel(label_dict[y_var]) - - else: - raise ValueError("Plot type not recognised.") + + fig = sns.jointplot(x=self.particles[x_var]*scale[x_var], + y=self.particles[y_var]*scale[y_var], + kind=kind) + plt.xlabel(label_dict[x_var]) + plt.ylabel(label_dict[y_var]) return fig -- GitLab