diff --git a/tracking/element.py b/tracking/element.py index 8692856c1c54e9f17f7a40fbed11bc693b0273f9..81f506cac6478144ed78ee60a5c588fb972ce863 100644 --- a/tracking/element.py +++ b/tracking/element.py @@ -9,9 +9,13 @@ included in the tracking. """ import numpy as np +import pandas as pd from abc import ABCMeta, abstractmethod from functools import wraps from mbtrack2.tracking.particles import Beam +from scipy import signal +import matplotlib.pyplot as plt +from scipy.interpolate import interp1d class Element(metaclass=ABCMeta): """ @@ -95,6 +99,203 @@ class LongitudinalMap(Element): bunch["delta"] -= self.ring.U0 / self.ring.E0 bunch["tau"] -= self.ring.ac * self.ring.T0 * bunch["delta"] +class WakePotential(Element): + """ + Resonator model based wake potential calculation for one turn. + + Parameters + ---------- + ring : Synchrotron object. + Q_factor : float, optional + Resonator quality factor. The default value is 1. + f_res : float, optional + Resonator resonance frequency in [Hz]. The default value is 10e9 Hz. + R_shunt : float, optional + Resonator shunt impedance in [Ohm]. The default value is 100 Ohm. + n_bin : int, optional + Number of bins for constructing the longitudinal bunch profile. + The default is 65. + + Attributes + ---------- + rho : array of shape (n_bin, ) + Bunch charge density profile. + tau : array of shape (n_bin + time_extra, ) + Time array starting from the head of the bunch until the wake tail + called timestop. + + The length of time_extra is determined by the last position of the + bunch time_bunch[-1], timestop, and the mean bin width of the bunch + profile mean_bin_size as + len(time_extra) = (timestop - time_bunch[-1]) / mean_bin_size + W_long : array of shape (n_bin + time_extra, ) + Wakefunction profile. + W_p : array of shape (n_bin + time_extra, ) + Wake potential profile. + wp : array of shape (mp_number, ) + Wake potential exerted on each macro-particle. + + Methods + ------- + charge_density(bunch, n_bin=65) + Calculate bunch charge density. + plot(self, var, plot_rho=True) + Plotting wakefunction or wake potential. + track(bunch) + Tracking method for the element. + + """ + + def __init__(self, ring, Q_factor=1, f_res=10e9, R_shunt=100, n_bin=65): + self.ring = ring + self.n_bin = n_bin + + self.Q_factor = Q_factor + self.omega_res = 2*np.pi*f_res + self.R_shunt = R_shunt + + if Q_factor >= 0.5: + self.Q_factor_p = np.sqrt(self.Q_factor**2 - 0.25) + self.omega_res_p = (self.omega_res*self.Q_factor_p)/self.Q_factor + else: + self.Q_factor_pp = np.sqrt(0.25 - self.Q_factor**2) + self.omega_res_p = (self.omega_res*self.Q_factor_pp)/self.Q_factor + + def charge_density(self, bunch, n_bin): + self.bins = bunch.binning(n_bin=self.n_bin) + self.bin_data = self.bins[2] + self.bin_size = self.bins[0].length + + self.rho = bunch.charge_per_mp*self.bin_data/ \ + (self.bin_size*bunch.charge) + + def init_timestop(self): + self.timestop = round(np.log(1000)/self.omega_res*2*self.Q_factor, 12) + + def time_array(self): + time_bunch = self.bins[0].mid + mean_bin_size = np.mean(self.bin_size) + time_extra = np.arange(start = time_bunch[-1]+mean_bin_size, + stop = self.timestop, step = mean_bin_size) + + self.tau = np.concatenate((time_bunch,time_extra)) + + def long_wakefunction(self): + w_list = [] + if self.Q_factor >= 0.5: + for t in self.tau: + if t >= 0: + w_long = -(self.omega_res*self.R_shunt/self.Q_factor)*\ + np.exp(-self.omega_res*t/(2*self.Q_factor))*\ + (np.cos(self.omega_res_p*t)-\ + np.sin(self.omega_res_p*t)/(2*self.Q_factor_p)) + else: + w_long = 0 + + w_list.append(w_long) + + elif self.Q_factor < 0.5: + for t in self.tau: + if t >= 0: + w_long = -(self.omega_res*self.R_shunt/self.Q_factor)*\ + np.exp(-self.omega_res*t/(2*self.Q_factor))*\ + (np.cosh(self.omega_res_p*t)-\ + np.sinh(self.omega_res_p*t)/(2*self.Q_factor_pp)) + else: + w_long = 0 + + w_list.append(w_long) + + self.W_long = np.array(w_list) + + def wake_potential(self): + self.W_p = signal.convolve(self.W_long*1e-12, self.rho, mode="same") + + def plot(self, var, plot_rho=True): + """ + Plotting wakefunction or wake potential. + + Parameters + ---------- + var : {'W_p', 'W_long' } + If 'W_p', the wake potential is plotted. + If 'W_long', the wakefunction is plotted. + plot_rho : bool, optional + Overlay the bunch charge density profile on the plot. + The default is True. + + Returns + ------- + fig : Figure + Figure object with the plot on it. + + """ + + fig, ax = plt.subplots() + + if var == "W_p": + ax.plot(self.tau*1e12, self.W_p*1e-12) + + ax.set_xlabel("$\\tau$ (ps)") + ax.set_ylabel("W$_p$ (V/pC)") + + elif var == "W_long": + ax.plot(self.tau*1e12, self.W_long*1e-12) + ax.set_xlabel("$\\tau$ (ps)") + ax.set_ylabel("W$_{||}$ ($\\Omega$/ps)") + + if plot_rho is True: + rho_array = np.array(self.rho) + rho_rescaled = rho_array/max(rho_array)*max(self.W_p) + + ax.plot(self.bins[0].mid*1e12, rho_rescaled*1e-12) + + else: + pass + + return fig + + def check_wake_tail(self): + """ + Checking whether the full wakefunction is obtained by the calculated + initial timestop. + + """ + + ratio = np.abs(min(self.W_long) / self.W_long[-6:-1]) + while any(ratio < 1000): + # Initial timestop is too short. + # Extending timestop by 50 ps and recalculating." + self.timestop += 50e-12 + self.time_array() + self.long_wakefunction() + ratio = np.abs(min(self.W_long) / self.W_long[-6:-1]) + + @Element.parallel + def track(self, bunch): + """ + Tracking method for the element. + No bunch to bunch interaction, so written for Bunch objects and + @Element.parallel is used to handle Beam objects. + + Parameters + ---------- + bunch : Bunch or Beam object. + + """ + + self.charge_density(bunch, n_bin = self.n_bin) + self.init_timestop() + self.time_array() + self.long_wakefunction() + self.check_wake_tail() + self.wake_potential() + + f = interp1d(self.tau, self.W_p, fill_value = 0, bounds_error = False) + self.wp = f(bunch["tau"]) + + bunch["delta"] += self.wp * bunch.charge / self.ring.E0 + class SynchrotronRadiation(Element): """ Element to handle synchrotron radiation, radiation damping and quantum @@ -204,4 +405,4 @@ class TransverseMap(Element): bunch["x"] = x bunch["xp"] = xp bunch["y"] = y - bunch["yp"] = yp \ No newline at end of file + bunch["yp"] = yp diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py index 9ae1449936bf8c12e56b951c543bb48391a3e677..273832166adbda6625b9a9c9db891628943e876d 100644 --- a/tracking/monitors/plotting.py +++ b/tracking/monitors/plotting.py @@ -9,29 +9,134 @@ tracking. import numpy as np import matplotlib.pyplot as plt +import matplotlib as mpl import seaborn as sns import h5py as hp +import random +def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"): + """ + Plot data recorded by BeamMonitor. + + Parameters + ---------- + filename : str + Name of the HDF5 file that contains the data. + dataset : {"current","emit","mean","std"} + HDF5 file's dataset to be plotted. + option : str, optional + If dataset is "emit", "mean", or "std", the variable name to be plotted + needs to be specified : + for "emit", option = {"x","y","s"} + for "mean" and "std", option = {"x","xp","y","yp","tau","delta"} + stat_var : {"mean", "std"}, optional + Statistical value of option. Except when dataset = "current", stat_var + needs to be specified. The default is None. + x_var : str, optional + Variable to be plotted on the horizontal axis. The default is "time". + + Return + ------ + fig : Figure + Figure object with the plot on it. + + """ + + file = hp.File(filename, "r") + path = file["Beam"] + + if dataset == "current": + total_current = [] + for i in range (len(path["time"])): + total_current.append(np.sum(path["current"][:,i])*1e3) + + fig, ax = plt.subplots() + ax.plot(path["time"],total_current) + ax.set_xlabel("Number of turns") + ax.set_ylabel("total current (mA)") + + elif dataset == "emit": + option_dict = {"x":0, "y":1, "s":2} #input option + axis = option_dict[option] + scale = [1e12, 1e12, 1e15] + label = ["$\\epsilon_{x}$ (pm.rad)", + "$\\epsilon_{y}$ (pm.rad)", + "$\\epsilon_{s}$ (fm.rad)"] + + if stat_var == "mean": + mean_emit = [] + for i in range (len(path["time"])): + mean_emit.append(np.mean(path["emit"][axis,:,i])*scale[axis]) + + fig, ax = plt.subplots() + ax.plot(path["time"],mean_emit) + + elif stat_var == "std": + std_emit = [] + for i in range (len(path["time"])): + std_emit.append(np.std(path["emit"][axis,:,i])*scale[axis]) + + fig, ax = plt.subplots() + ax.plot(path["time"],std_emit) + + ax.set_xlabel("Number of turns") + ax.set_ylabel(stat_var+" " + label[axis]) + + elif dataset == "mean" or dataset == "std": + option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + axis = option_dict[option] + scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] + label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", + "$\\tau$ (ps)", "$\\delta$"] + + fig, ax = plt.subplots() + if stat_var == "mean": + mean_list = [] + for i in range (len(path["time"])): + mean_list.append(np.mean(path[dataset][axis,:,i]*scale[axis])) + + ax.plot(path["time"],mean_list) + label_sup = {"mean":"", "std":"std of "} # input stat_var + + elif stat_var == "std": + std_list = [] + for i in range (len(path["time"])): + std_list.append(np.std(path[dataset][axis,:,i]*scale[axis])) + + ax.plot(path["time"],std_list) + label_sup = {"mean":"", "std":"std of "} #input stat_var + + ax.set_xlabel("Number of turns") + ax.set_ylabel(label_sup[stat_var] + dataset +" "+ label[axis]) + + file.close() + return fig + def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): """ - Plot data recorded from a BunchMonitor. + Plot data recorded by BunchMonitor. Parameters ---------- - filename : str + filename : str Name of the HDF5 file that contains the data. bunch_number : int - The bunch number whose data has been saved in the HDF5 file. - This has to be identical to 'bunch_number' parameter in 'BunchMonitor' object. - detaset : str {"current","emit","mean","std"} + Bunch to plot. This has to be identical to 'bunch_number' parameter in + 'BunchMonitor' object. + detaset : {"current","emit","mean","std"} HDF5 file's dataset to be plotted. option : str, optional If dataset is "emit", "mean", or "std", the variable name to be plotted needs to be specified : for "emit", option = {"x","y","s"} for "mean" and "std", option = {"x","xp","y","yp","tau","delta"} - x_var : str, optional + x_var : {"time", "current"}, optional Variable to be plotted on the horizontal axis. The default is "time". + + Return + ------ + fig : Figure + Figure object with the plot on it. """ @@ -44,56 +149,57 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"): label = "current (mA)" elif dataset == "emit": - # Specifying the axis - # horizontal axis (x) -> 0 - # vertical axis (y) -> 1 - # longitudinal axis (s) -> 2 - - emit_axis = {"x":0, "y":1, "s":2} + option_dict = {"x":0, "y":1, "s":2} - y_var = file[group][dataset][emit_axis[option]]*1e9 + y_var = file[group][dataset][option_dict[option]]*1e9 if option == "x": label = "hor. emittance (nm.rad)" elif option == "y": label = "ver. emittance (nm.rad)" elif option == "s": label = "long. emittance (nm.rad)" - elif dataset == "mean" or "std": - # Specify the variable - # x -> 0 - # xp -> 1 - # y -> 2 - # yp -> 3 - # tau -> 4 - # delta -> 5 - - var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} - scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] - label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", - "$\\tau$ (ps)", "$\\delta$"] + elif dataset == "mean" or "std": + option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] + axis_index = option_dict[option] - y_var = file[group][dataset][var_dict[option]]*scale[var_dict[option]] - label = label_list[var_dict[option]] + y_var = file[group][dataset][axis_index]*scale[axis_index] + if dataset == "mean": + label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)", + "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"] + else: + label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)", + "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", + "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"] - - plt.plot(file[group]["time"][:],y_var) - plt.xlabel("number of turns") - plt.ylabel(label) + label = label_list[axis_index] + + if x_var == "current": + x_axis = file[group]["current"][:] * 1e3 + xlabel = "current (mA)" + elif x_var == "time": + x_axis = file[group]["time"][:] + xlabel = "number of turns" + + fig, ax = plt.subplots() + ax.plot(x_axis,y_var) + ax.set_xlabel(xlabel) + ax.set_ylabel(label) file.close() + return fig -def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, - only_alive=True): +def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, + only_alive=True, plot_size=1): """ - Plot data recorded from a PhaseSpaceMonitor. + Plot data recorded by PhaseSpaceMonitor. Parameters ---------- filename : str Name of the HDF5 file that contains the data. bunch_number : int - Number of the bunch whose data has been saved in the HDF5 file. - This has to be identical to 'bunch_number' parameter in + Bunch to plot. This has to be identical to 'bunch_number' parameter in 'PhaseSpaceMonitor' object. x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"} If dataset is "particles", the variables to be plotted on the @@ -103,6 +209,15 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, only_alive : bool, optional When only_alive is True, only alive particles are plotted and dead particles will be discarded. + plot_size : [0,1], optional + Number of macro-particles to plot relative to the total number + of macro-particles recorded. This option helps reduce processing time + when the data is big. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ file = hp.File(filename, "r") @@ -110,15 +225,7 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, group = "PhaseSpaceData_{0}".format(bunch_number) dataset = "particles" - # Specify the parameter - # x -> 0 - # xp -> 1 - # y -> 2 - # yp -> 3 - # tau -> 4 - # delta -> 5 - - var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} scale = [1e3,1e3,1e3,1e3,1e12,1] label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)", "$\\delta$"] @@ -131,27 +238,139 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, format(turn, file[group]["time"][:])) path = file[group][dataset] - - if only_alive is False: - # format : sns.jointplot(x_axis, yaxis, kind) - x_axis = path[:,var_dict[x_var],turn_index[0][0]] - y_axis = path[:,var_dict[y_var],turn_index[0][0]] - - elif only_alive is True: - alive_index = np.where(file[group]["alive"][:,turn_index]) - - x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]] - y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]] - - sns.jointplot(x_axis*scale[var_dict[x_var]], - y_axis*scale[var_dict[y_var]], kind="kde") + mp_number = path[:,0,0].size - plt.xlabel(label[var_dict[x_var]]) - plt.ylabel(label[var_dict[y_var]]) + if only_alive is True: + index = np.where(file[group]["alive"][:,turn_index])[0] + else: + index = np.arange(mp_number) + + if plot_size == 1: + samples = index + elif plot_size < 1: + samples_meta = random.sample(list(index), int(plot_size*mp_number)) + samples = sorted(samples_meta) + else: + raise ValueError("plot_size must be in range [0,1].") + + # format : sns.jointplot(x_axis, yaxis, kind) + x_axis = path[samples,option_dict[x_var],turn_index[0][0]] + y_axis = path[samples,option_dict[y_var],turn_index[0][0]] + + fig = sns.jointplot(x_axis*scale[option_dict[x_var]], + y_axis*scale[option_dict[y_var]], kind="kde") + + plt.xlabel(label[option_dict[x_var]]) + plt.ylabel(label[option_dict[y_var]]) file.close() + return fig + +def plot_profiledata(filename, bunch_number, dimension="tau", start=0, + stop=None, step=None, profile_plot=True, streak_plot=True): + """ + Plot data recorded by ProfileMonitor + Parameters + ---------- + filename : str + Name of the HDF5 file that contains the data. + bunch_number : int + Bunch to plot. This has to be identical to 'bunch_number' parameter in + 'ProfileMonitor' object. + dimension : str, optional + Dimension to plot. The default is "tau" + start : int, optional + First turn to plot. The default is 0. + stop : int, optional + Last turn to plot. If None, the last turn of the record is selected. + step : int, optional + Plotting step. This has to be divisible by 'save_every' parameter in + 'ProfileMonitor' object, i.e. step % save_every == 0. If None, step is + equivalent to save_every. + profile_plot : bool, optional + If Ture, bunch profile plot is plotted. + streak_plot : bool, optional + If True, strek plot is plotted. + + Returns + ------- + fig : Figure + Figure object with the plot on it. + + """ + + file = hp.File(filename, "r") + path = file['ProfileData_{0}'.format(bunch_number)] + l_bound = path["{0}_bin".format(dimension)] + if stop is None: + stop = path['time'][-1] + elif stop not in path['time']: + raise ValueError("stop not found. Choose from {0}" + .format(path['time'][:])) + + if start not in path['time']: + raise ValueError("start not found. Choose from {0}" + .format(path['time'][:])) - + save_every = path['time'][1] - path['time'][0] + if step is None: + step = save_every + + if step % save_every != 0: + raise ValueError("step must be divisible by the recording step " + "which is {0}.".format(save_every)) + + dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} + scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] + label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)", + "$\\tau$ (ps)", "$\\delta$"] + + num = int((stop - start)/step) + n_bin = len(path[dimension][:,0]) + + start_index = np.where(path['time'][:] == start)[0][0] + + x_var = np.zeros((num+1,n_bin)) + turn_index_array = np.zeros((num+1,)) + for i in range(num+1): + turn_index = start_index + i * step / save_every + turn_index_array[i] = turn_index + # construct an array of bin mids + x_var[i,:] = 0.5*(l_bound[:-1,turn_index]+l_bound[1:,turn_index]) + + if profile_plot is True: + fig, ax = plt.subplots() + for i in range(num+1): + ax.plot(x_var[i]*scale[dimension_dict[dimension]], + path[dimension][:,turn_index_array[i]], + label="turn {0}".format(path['time'][turn_index_array[i]])) + ax.set_xlabel(label[dimension_dict[dimension]]) + ax.set_ylabel("number of macro-particles") + ax.legend() + + if streak_plot is True: + turn = np.reshape(path['time'][turn_index_array], (num+1,1)) + y_var = np.ones((num+1,n_bin)) * turn + z_var = np.transpose(path[dimension][:,turn_index_array]) + fig2, ax2 = plt.subplots() + cmap = mpl.cm.cool + c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto', + extent=[x_var.min()*scale[dimension_dict[dimension]], + x_var.max()*scale[dimension_dict[dimension]], + y_var.min(),y_var.max()]) + ax2.set_xlabel(label[dimension_dict[dimension]]) + ax2.set_ylabel("Number of turns") + cbar = fig2.colorbar(c, ax=ax2) + cbar.set_label("Number of macro-particles") + + file.close() + if profile_plot is True and streak_plot is True: + return fig, fig2 + elif profile_plot is True: + return fig + elif streak_plot is True: + return fig2 + \ No newline at end of file diff --git a/tracking/optics.py b/tracking/optics.py index 2e3483657181ebab0073f5d9eb135ac7808d07dd..da6ae20aec203e9828cf92d9066fda9f79345357 100644 --- a/tracking/optics.py +++ b/tracking/optics.py @@ -55,6 +55,8 @@ class Optics: Return gamma functions at specific locations given by position. dispersion(position) Return dispersion functions at specific locations given by position. + plot(self, var, option, n_points=1000) + Plot optical variables. """ def __init__(self, lattice_file=None, local_beta=None, local_alpha=None, @@ -244,6 +246,66 @@ class Optics: dispersion = [self.dispX(position), self.disppX(position), self.dispY(position), self.disppY(position)] return np.array(dispersion) + + def plot(self, var, option, n_points=1000): + """ + Plot optical variables. + + Parameters + ---------- + var : {"beta", "alpha", "gamma", "dispersion"} + Optical variable to be plotted. + option : str + If var = "beta", "alpha" and "gamma", option = {"x","y"} specifying + the axis of interest. + If var = "dispersion", option = {"x","px","y","py"} specifying the + axis of interest for the dispersion function or its derivative. + n_points : int + Number of points on the plot. The default value is 1000. + + """ + + var_dict = {"beta":self.beta, "alpha":self.alpha, "gamma":self.gamma, + "dispersion":self.dispersion} + + if var == "dispersion": + option_dict = {"x":0, "px":1, "y":2, "py":3} + + label = ["D$_{x}$ (m)", "D'$_{x}$", "D$_{y}$ (m)", "D'$_{y}$"] + + ylabel = label[option_dict[option]] + + + elif var=="beta" or var=="alpha" or var=="gamma": + option_dict = {"x":0, "y":1} + label_dict = {"beta":"$\\beta$", "alpha":"$\\alpha$", + "gamma":"$\\gamma$"} + + if option == "x": label_sup = "$_{x}$" + elif option == "y": label_sup = "$_{y}$" + + unit = {"beta":" (m)", "alpha":"", "gamma":" (m$^{-1}$)"} + + ylabel = label_dict[var] + label_sup + unit[var] + + + else: + raise ValueError("Variable name is not found.") + + if self.use_local_values is not True: + position = np.linspace(0, self.lattice.circumference, int(n_points)) + else: + position = np.linspace(0,1) + + var_list = var_dict[var](position)[option_dict[option]] + fig, ax = plt.subplots() + ax.plot(position,var_list) + + ax.set_xlabel("position (m)") + ax.set_ylabel(ylabel) + + return fig + class PhyisicalModel: """ @@ -514,3 +576,4 @@ class PhyisicalModel: axs[1].set(xlabel="Longitudinal position [m]", ylabel="Vertical aperture [mm]") axs[1].legend(["Top","Bottom"]) + diff --git a/tracking/particles.py b/tracking/particles.py index d3e7f171ab827c2aeec740a19c7cc17c940ee47d..7a97fa35b64ca960442bcb856103b0e4c3ab8722 100644 --- a/tracking/particles.py +++ b/tracking/particles.py @@ -336,26 +336,35 @@ class Bunch: plot_type : str {"j" , "sc"} Type of the plot. The defualt value is "j" for a joint plot. Can be modified to "sc" for a scatter plot. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"} scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1} + if plot_type == "sc": - plt.scatter(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var]) - plt.xlabel(label_dict[x_var]) - plt.ylabel(label_dict[y_var]) + fig, ax = plt.subplots() + ax.scatter(self.particles[x_var]*scale[x_var], + self.particles[y_var]*scale[y_var]) + ax.set_xlabel(label_dict[x_var]) + ax.set_ylabel(label_dict[y_var]) elif plot_type == "j": - sns.jointplot(self.particles[x_var]*scale[x_var], - self.particles[y_var]*scale[y_var],kind="kde") + fig = sns.jointplot(self.particles[x_var]*scale[x_var], + self.particles[y_var]*scale[y_var],kind="kde") plt.xlabel(label_dict[x_var]) plt.ylabel(label_dict[y_var]) else: raise ValueError("Plot type not recognised.") + + return fig class Beam: """ @@ -630,6 +639,11 @@ class Beam: option = {"x","xp","y","yp","tau","delta"}. For "bunch_emit", option = {"x","y","s"}. The default is None. + + Return + ------ + fig : Figure + Figure object with the plot on it. """ var_dict = {"bunch_current":self.bunch_current, @@ -639,6 +653,8 @@ class Beam: "bunch_std":self.bunch_std, "bunch_emit":self.bunch_emit} + fig, ax= plt.subplots() + if var == "bunch_mean" or var == "bunch_std": value_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1] @@ -654,13 +670,13 @@ class Beam: where_is_nan = np.isnan(y_axis) y_axis[where_is_nan] = 0 - plt.plot(np.arange(len(self.filling_pattern)), + ax.plot(np.arange(len(self.filling_pattern)), y_axis*scale[value_dict[option]]) - plt.xlabel('bunch number') + ax.set_xlabel('bunch number') if var == "bunch_mean": - plt.ylabel(label_mean[value_dict[option]]) + ax.set_ylabel(label_mean[value_dict[option]]) else: - plt.ylabel(label_std[value_dict[option]]) + ax.set_ylabel(label_std[value_dict[option]]) elif var == "bunch_emit": value_dict = {"x":0, "y":1, "s":2} @@ -672,33 +688,34 @@ class Beam: where_is_nan = np.isnan(y_axis) y_axis[where_is_nan] = 0 - plt.plot(np.arange(len(self.filling_pattern)), + ax.plot(np.arange(len(self.filling_pattern)), y_axis*scale[value_dict[option]]) if option == "x": label_y = "hor. emittance (nm.rad)" elif option == "y": label_y = "ver. emittance (nm.rad)" elif option == "s": label_y = "long. emittance (fm.rad)" - plt.xlabel('bunch number') - plt.ylabel(label_y) + ax.set_xlabel('bunch number') + ax.set_ylabel(label_y) elif var=="bunch_current" or var=="bunch_charge" or var=="bunch_particle": scale = {"bunch_current":1e3, "bunch_charge":1e9, "bunch_particle":1} - plt.plot(np.arange(len(self.filling_pattern)), var_dict[var]* + ax.plot(np.arange(len(self.filling_pattern)), var_dict[var]* scale[var]) - plt.xlabel('bunch number') + ax.set_xlabel('bunch number') if var == "bunch_current": label_y = "bunch current (mA)" elif var == "bunch_charge": label_y = "bunch chagre (nC)" else: label_y = "number of particles" - plt.ylabel(label_y) + ax.set_ylabel(label_y) elif var == "current" or var=="charge" or var=="particle_number": - print("'{0}'is a total value and cannot be plotted".format(var)) + raise ValueError("'{0}'is a total value and cannot be plotted." + .format(var)) - + return fig