From b4350aa3e719be59a47befbfb59ef50623ede702 Mon Sep 17 00:00:00 2001 From: Gamelin Alexis <gamelin@synchrotron-soleil.fr> Date: Fri, 24 Apr 2020 17:46:22 +0200 Subject: [PATCH] Rebase on master and rework plot_phasespacedata Rebase test_plot on master plot_phasespacedata -> only "particles" dataset is plotted plot_phasespacedata -> mandatory arguments --- tracking/particles.py | 25 +++++---- tracking/plotting.py | 117 ++++++++++++++++++------------------------ 2 files changed, 64 insertions(+), 78 deletions(-) diff --git a/tracking/particles.py b/tracking/particles.py index fc52fa8..cb4c046 100644 --- a/tracking/particles.py +++ b/tracking/particles.py @@ -88,6 +88,8 @@ class Bunch: ------- init_gaussian(cov=None, mean=None, **kwargs) Initialize bunch particles with 6D gaussian phase space. + plot_phasespace(x_var="tau", y_var="delta", plot_type="j") + Plot phase space. References ---------- @@ -321,18 +323,18 @@ class Bunch: ax = fig.gca() ax.plot(bins.mid, profile) - def getplot(self,x,y,p_type): + def plot_phasespace(self, x_var="tau", y_var="delta", plot_type="j"): """ - Plot longitudinal phase space. + Plot phase space. Parameters ---------- x_var : str - name from Bunch object to plot on horizontal axis. + Dimension to plot on horizontal axis. y_var : str - name from Bunch object to plot on ver. axis. + Dimension to plot on vertical axis. plot_type : str {"j" , "sc"} - type of the plot. The defualt value is "j" for a joint plot. + Type of the plot. The defualt value is "j" for a joint plot. Can be modified to "sc" for a scatter plot. """ @@ -352,7 +354,8 @@ class Bunch: plt.xlabel(label_dict[x_var]) plt.ylabel(label_dict[y_var]) - else: raise ValueError("Plot type not recognised") + else: + raise ValueError("Plot type not recognised.") class Beam: """ @@ -390,7 +393,6 @@ class Beam: Status of MPI parallelisation, should not be changed directly but with mpi_init() and mpi_close() - Methods ------ init_beam(filling_pattern, current_per_bunch=1e-3, mp_per_bunch=1e3) @@ -404,6 +406,8 @@ class Beam: all processors. Rather slow mpi_close() Call mpi_gather and switch off MPI parallelisation + plot(var, option=None) + Plot variables with respect to bunch number. """ def __init__(self, ring, bunch_list=None): @@ -610,13 +614,13 @@ class Beam: self.mpi_switch = False self.mpi = None - def plot_bunchnumber(self, var, option=None): + def plot(self, var, option=None): """ - Plot varviables with respect to bunch number. + Plot variables with respect to bunch number. Parameters ---------- - var : str {"bunch_currebt", "bunch_charge", "bunch_particle", + var : str {"bunch_current", "bunch_charge", "bunch_particle", "bunch_mean", "bunch_std", "bunch_emit"} Variable to be plotted. option : str, optional @@ -626,7 +630,6 @@ class Beam: option = {"x","xp","y","yp","tau","delta"}. For "bunch_emit", option = {"x","y","s"}. The default is None. - """ var_dict = {"bunch_current":self.bunch_current, diff --git a/tracking/plotting.py b/tracking/plotting.py index ac143f5..9ae1449 100644 --- a/tracking/plotting.py +++ b/tracking/plotting.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """ -Module for plotting Bunch and Beam objects. +Module for plotting the data recorded by the monitor module during the +tracking. @author: Watanyu Foosang @Date: 10/04/2020 @@ -11,10 +12,9 @@ import matplotlib.pyplot as plt import seaborn as sns import h5py as hp - def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): """ - Plot the evolution of the variables from the Beam object. + Plot data recorded from a BunchMonitor. Parameters ---------- @@ -82,10 +82,10 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): file.close() -def plot_phasespacedata(filename, bunch_number, dataset, x_var=None, - y_var=None, turn=None, only_alive=True): +def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, + only_alive=True): """ - Plot data from PhaseSpaceData_0 group of the HDF5 file. + Plot data recorded from a PhaseSpaceMonitor. Parameters ---------- @@ -93,78 +93,61 @@ def plot_phasespacedata(filename, bunch_number, dataset, x_var=None, Name of the HDF5 file that contains the data. bunch_number : int Number of the bunch whose data has been saved in the HDF5 file. - This has to be identical to 'bunch_number' parameter in 'PhaseSpaceMonitor' object. - dataset : str {'alive', 'partcicles'} - HDF5 file's dataset to be plotted. - x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}, optional - If dataset is "particles", the variables to be plotted on the horizontal - and the vertical axes need to be specified. - turn : int, optional + This has to be identical to 'bunch_number' parameter in + 'PhaseSpaceMonitor' object. + x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"} + If dataset is "particles", the variables to be plotted on the + horizontal and the vertical axes need to be specified. + turn : int Turn at which the data will be extracted. - only_alive : bool - When only_alive is True, only alive particles are plotted and dead particles will be discarded. - + only_alive : bool, optional + When only_alive is True, only alive particles are plotted and dead + particles will be discarded. """ file = hp.File(filename, "r") group = "PhaseSpaceData_{0}".format(bunch_number) - - if dataset == "alive": - alive_at_a_time = [] - for i in range (len(file[group]["time"])): - alive_at_a_time.append(np.sum(file[group][dataset][:,i])) - - plt.plot(file[group]["time"],alive_at_a_time) - plt.xlabel("number of turns") - plt.ylabel("number of alive particles") - - elif dataset == "particles": - # Specify the parameter - # x -> 0 - # xp -> 1 - # y -> 2 - # yp -> 3 - # tau -> 4 - # delta -> 5 - - var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} - scale = [1e3,1e3,1e3,1e3,1e12,1] - label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)", - "$\\delta$"] - - - # find the index of "turn" in time array - turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) - - if len(turn_index[0]) == 0: - raise ValueError("Turn {0} is not found. Enter turn from {1}.". - format(turn, file[group]["time"][:])) + dataset = "particles" - else : pass - - path = file[group][dataset] - - - if only_alive is False: - # format : sns.jointplot(x_axis, yaxis, kind) - x_axis = path[:,var_dict[x_var],turn_index[0][0]] - y_axis = path[:,var_dict[y_var],turn_index[0][0]] + # Specify the parameter + # x -> 0 + # xp -> 1 + # y -> 2 + # yp -> 3 + # tau -> 4 + # delta -> 5 + + var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + scale = [1e3,1e3,1e3,1e3,1e12,1] + label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)", + "$\\delta$"] + + # find the index of "turn" in time array + turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) + + if len(turn_index[0]) == 0: + raise ValueError("Turn {0} is not found. Enter turn from {1}.". + format(turn, file[group]["time"][:])) + + path = file[group][dataset] + if only_alive is False: + # format : sns.jointplot(x_axis, yaxis, kind) + x_axis = path[:,var_dict[x_var],turn_index[0][0]] + y_axis = path[:,var_dict[y_var],turn_index[0][0]] - elif only_alive is True: - alive_index = np.where(file[group]["alive"][:,turn_index]) + elif only_alive is True: + alive_index = np.where(file[group]["alive"][:,turn_index]) - x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]] - y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]] - - - sns.jointplot(x_axis*scale[var_dict[x_var]], - y_axis*scale[var_dict[y_var]], kind="kde") - - plt.xlabel(label[var_dict[x_var]]) - plt.ylabel(label[var_dict[y_var]]) + x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]] + y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]] + sns.jointplot(x_axis*scale[var_dict[x_var]], + y_axis*scale[var_dict[y_var]], kind="kde") + + plt.xlabel(label[var_dict[x_var]]) + plt.ylabel(label[var_dict[y_var]]) file.close() -- GitLab