diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py index 0a9cad7e174ceb1d11d555b3355f5949d3d51ab0..f28b0df341fc7756b245550a80f8ff043aded3bd 100644 --- a/tracking/monitors/plotting.py +++ b/tracking/monitors/plotting.py @@ -32,6 +32,11 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"): needs to be specified. The default is None. x_var : str, optional Variable to be plotted on the horizontal axis. The default is "time". + + Return + ------ + fig : Figure + Figure object with the plot on it. """ @@ -43,9 +48,10 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"): for i in range (len(path["time"])): total_current.append(np.sum(path["current"][:,i])*1e3) - plt.plot(path["time"],total_current) - plt.xlabel("Number of turns") - plt.ylabel("total current (mA)") + fig, ax = plt.subplots() + ax.plot(path["time"],total_current) + ax.set_xlabel("Number of turns") + ax.set_ylabel("total current (mA)") elif dataset == "emit": option_dict = {"x":0, "y":1, "s":2} #input option @@ -60,75 +66,50 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"): for i in range (len(path["time"])): mean_emit.append(np.mean(path["emit"][axis,:,i])*scale[axis]) - plt.plot(path["time"],mean_emit) - label_sup = "avg. " + fig, ax = plt.subplots() + ax.plot(path["time"],mean_emit) elif stat_var == "std": std_emit = [] for i in range (len(path["time"])): std_emit.append(np.std(path["emit"][axis,:,i])*scale[axis]) - - plt.plot(path["time"],std_emit) - label_sup = "std. " + + fig, ax = plt.subplots() + ax.plot(path["time"],std_emit) - plt.xlabel("Number of turns") - plt.ylabel(label_sup + label[axis]) + ax.set_xlabel("Number of turns") + ax.set_ylabel(stat_var+" " + label[axis]) - elif dataset == "mean": + elif dataset == "mean" or dataset == "std": option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} axis = option_dict[option] scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"] + fig, ax = plt.subplots() if stat_var == "mean": - mean_mean = [] + mean_list = [] for i in range (len(path["time"])): - mean_mean.append(np.mean(path["mean"][axis,:,i]*scale[axis])) + mean_list.append(np.mean(path[dataset][axis,:,i]*scale[axis])) - plt.plot(path["time"],mean_mean) - label_sup = "avg. " + ax.plot(path["time"],mean_list) + label_sup = {"mean":"", "std":"std of "} # input stat_var elif stat_var == "std": - std_mean = [] + std_list = [] for i in range (len(path["time"])): - std_mean.append(np.std(path["mean"][axis,:,i]*scale[axis])) + std_list.append(np.std(path[dataset][axis,:,i]*scale[axis])) - plt.plot(path["time"],std_mean) - label_sup = "std. of avg. " - - plt.xlabel("Number of turns") - plt.ylabel(label_sup + label[axis]) - - elif dataset == "std": - option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} - axis = option_dict[option] - scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] - label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", - "$\\tau$ (ps)", "$\\delta$"] - - if stat_var == "mean": - mean_std = [] - for i in range (len(path["time"])): - mean_std.append(np.mean(path["std"][axis,:,i]*scale[axis])) - - plt.plot(path["time"],mean_std) - label_sup = "std. " + ax.plot(path["time"],std_list) + label_sup = {"mean":"", "std":"std of "} #input stat_var - elif stat_var == "std": - std_std = [] - for i in range (len(path["time"])): - std_std.append(np.std(path["std"][axis,:,i]*scale[axis])) - - plt.plot(path["time"],std_std) - label_sup = "std. of std. " - - plt.xlabel("Number of turns") - plt.ylabel(label_sup + label[axis]) + ax.set_xlabel("Number of turns") + ax.set_ylabel(label_sup[stat_var] + dataset +" "+ label[axis]) file.close() - - + return fig + def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): """ Plot data recorded from a BunchMonitor. @@ -149,6 +130,11 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): for "mean" and "std", option = {"x","xp","y","yp","tau","delta"} x_var : str, optional Variable to be plotted on the horizontal axis. The default is "time". + + Return + ------ + fig : Figure + Figure object with the plot on it. """ @@ -161,43 +147,32 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): label = "current (mA)" elif dataset == "emit": - # Specifying the axis - # horizontal axis (x) -> 0 - # vertical axis (y) -> 1 - # longitudinal axis (s) -> 2 - - emit_axis = {"x":0, "y":1, "s":2} + option_dict = {"x":0, "y":1, "s":2} - y_var = file[group][dataset][emit_axis[option]]*1e9 + y_var = file[group][dataset][option_dict[option]]*1e9 if option == "x": label = "hor. emittance (nm.rad)" elif option == "y": label = "ver. emittance (nm.rad)" elif option == "s": label = "long. emittance (nm.rad)" - elif dataset == "mean" or "std": - # Specify the variable - # x -> 0 - # xp -> 1 - # y -> 2 - # yp -> 3 - # tau -> 4 - # delta -> 5 - - var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + elif dataset == "mean" or "std": + option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"] - y_var = file[group][dataset][var_dict[option]]*scale[var_dict[option]] - label = label_list[var_dict[option]] + y_var = file[group][dataset][option_dict[option]]\ + *scale[option_dict[option]] + label = label_list[option_dict[option]] - - plt.plot(file[group]["time"][:],y_var) - plt.xlabel("number of turns") - plt.ylabel(label) + fig, ax = plt.subplots() + ax.plot(file[group]["time"][:],y_var) + ax.set_xlabel("number of turns") + ax.set_ylabel(label) file.close() + return fig def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, only_alive=True): @@ -220,6 +195,11 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, only_alive : bool, optional When only_alive is True, only alive particles are plotted and dead particles will be discarded. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ file = hp.File(filename, "r") @@ -227,15 +207,7 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, group = "PhaseSpaceData_{0}".format(bunch_number) dataset = "particles" - # Specify the parameter - # x -> 0 - # xp -> 1 - # y -> 2 - # yp -> 3 - # tau -> 4 - # delta -> 5 - - var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} scale = [1e3,1e3,1e3,1e3,1e12,1] label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)", "$\\delta$"] @@ -251,24 +223,20 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, if only_alive is False: # format : sns.jointplot(x_axis, yaxis, kind) - x_axis = path[:,var_dict[x_var],turn_index[0][0]] - y_axis = path[:,var_dict[y_var],turn_index[0][0]] + x_axis = path[:,option_dict[x_var],turn_index[0][0]] + y_axis = path[:,option_dict[y_var],turn_index[0][0]] elif only_alive is True: alive_index = np.where(file[group]["alive"][:,turn_index]) - x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]] - y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]] + x_axis = path[alive_index[0],option_dict[x_var],turn_index[0][0]] + y_axis = path[alive_index[0],option_dict[y_var],turn_index[0][0]] - sns.jointplot(x_axis*scale[var_dict[x_var]], - y_axis*scale[var_dict[y_var]], kind="kde") + fig = sns.jointplot(x_axis*scale[option_dict[x_var]], + y_axis*scale[option_dict[y_var]], kind="kde") - plt.xlabel(label[var_dict[x_var]]) - plt.ylabel(label[var_dict[y_var]]) + plt.xlabel(label[option_dict[x_var]]) + plt.ylabel(label[option_dict[y_var]]) file.close() - - - - - + return fig diff --git a/tracking/optics.py b/tracking/optics.py index be326c44431ae43988080dbc4571d4ce922de41b..da6ae20aec203e9828cf92d9066fda9f79345357 100644 --- a/tracking/optics.py +++ b/tracking/optics.py @@ -55,6 +55,8 @@ class Optics: Return gamma functions at specific locations given by position. dispersion(position) Return dispersion functions at specific locations given by position. + plot(self, var, option, n_points=1000) + Plot optical variables. """ def __init__(self, lattice_file=None, local_beta=None, local_alpha=None, @@ -245,9 +247,9 @@ class Optics: self.dispY(position), self.disppY(position)] return np.array(dispersion) - def plot_optics(self, var, option, n_points=1000): + def plot(self, var, option, n_points=1000): """ - Plotting optical variables. + Plot optical variables. Parameters ---------- @@ -271,7 +273,7 @@ class Optics: label = ["D$_{x}$ (m)", "D'$_{x}$", "D$_{y}$ (m)", "D'$_{y}$"] - plt.ylabel(label[option_dict[option]]) + ylabel = label[option_dict[option]] elif var=="beta" or var=="alpha" or var=="gamma": @@ -284,17 +286,25 @@ class Optics: unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)"} - plt.ylabel(label_dict[var] + label_sup + unit[var]) + ylabel = label_dict[var] + label_sup + unit[var] else: raise ValueError("Variable name is not found.") - position = np.linspace(0, self.lattice.circumference, int(n_points)) + if self.use_local_values is not True: + position = np.linspace(0, self.lattice.circumference, int(n_points)) + else: + position = np.linspace(0,1) + var_list = var_dict[var](position)[option_dict[option]] - plt.plot(position,var_list) + fig, ax = plt.subplots() + ax.plot(position,var_list) - plt.xlabel("position (m)") + ax.set_xlabel("position (m)") + ax.set_ylabel(ylabel) + + return fig class PhyisicalModel: diff --git a/tracking/particles.py b/tracking/particles.py index cb4c0468306f63b81fb9d45fc56f0127713b6f83..c6a802e7038fc4f4fb71cb806cad8777c8d1367d 100644 --- a/tracking/particles.py +++ b/tracking/particles.py @@ -336,26 +336,35 @@ class Bunch: plot_type : str {"j" , "sc"} Type of the plot. The defualt value is "j" for a joint plot. Can be modified to "sc" for a scatter plot. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"} scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1} + if plot_type == "sc": - plt.scatter(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var]) - plt.xlabel(label_dict[x_var]) - plt.ylabel(label_dict[y_var]) + fig, ax = plt.subplots() + ax.scatter(self.particles[x_var]*scale[x_var], + self.particles[y_var]*scale[y_var]) + ax.set_xlabel(label_dict[x_var]) + ax.set_ylabel(label_dict[y_var]) elif plot_type == "j": - sns.jointplot(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var],kind="kde") + fig = sns.jointplot(self.particles[x_var]*scale[x_var], + self.particles[y_var]*scale[y_var],kind="kde") plt.xlabel(label_dict[x_var]) plt.ylabel(label_dict[y_var]) else: raise ValueError("Plot type not recognised.") + + return fig class Beam: """ @@ -630,6 +639,11 @@ class Beam: option = {"x","xp","y","yp","tau","delta"}. For "bunch_emit", option = {"x","y","s"}. The default is None. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ var_dict = {"bunch_current":self.bunch_current, @@ -639,6 +653,8 @@ class Beam: "bunch_std":self.bunch_std, "bunch_emit":self.bunch_emit} + fig, ax= plt.subplots() + if var == "bunch_mean" or var == "bunch_std": value_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] @@ -654,13 +670,13 @@ class Beam: where_is_nan = np.isnan(y_axis) y_axis[where_is_nan] = 0 - plt.plot(np.arange(len(self.filling_pattern)), + ax.plot(np.arange(len(self.filling_pattern)), y_axis*scale[value_dict[option]]) - plt.xlabel('bunch number') + ax.set_xlabel('bunch number') if var == "bunch_mean": - plt.ylabel(label_mean[value_dict[option]]) + ax.set_ylabel(label_mean[value_dict[option]]) else: - plt.ylabel(label_std[value_dict[option]]) + ax.set_ylabel(label_std[value_dict[option]]) elif var == "bunch_emit": value_dict = {"x":0, "y":1, "s":2} @@ -672,33 +688,34 @@ class Beam: where_is_nan = np.isnan(y_axis) y_axis[where_is_nan] = 0 - plt.plot(np.arange(len(self.filling_pattern)), + ax.plot(np.arange(len(self.filling_pattern)), y_axis*scale[value_dict[option]]) if option == "x": label_y = "hor. emittance (nm.rad)" elif option == "y": label_y = "ver. emittance (nm.rad)" elif option == "s": label_y = "long. emittance (fm.rad)" - plt.xlabel('bunch number') - plt.ylabel(label_y) + ax.set_xlabel('bunch number') + ax.set_ylabel(label_y) elif var=="bunch_current" or var=="bunch_charge" or var=="bunch_particle": scale = {"bunch_current":1e3, "bunch_charge":1e9, "bunch_particle":1} - plt.plot(np.arange(len(self.filling_pattern)), var_dict[var]* + ax.plot(np.arange(len(self.filling_pattern)), var_dict[var]* scale[var]) - plt.xlabel('bunch number') + ax.set_xlabel('bunch number') if var == "bunch_current": label_y = "bunch current (mA)" elif var == "bunch_charge": label_y = "bunch chagre (nC)" else: label_y = "number of particles" - plt.ylabel(label_y) + ax.set_ylabel(label_y) elif var == "current" or var=="charge" or var=="particle_number": - print("'{0}'is a total value and cannot be plotted".format(var)) + raise ValueError("'{0}'is a total value and cannot be plotted." + .format(var)) - + return fig