Skip to content
Snippets Groups Projects
Commit dc83a72d authored by Watanyu Foosang's avatar Watanyu Foosang
Browse files

Adding plotting methods in Beam class

Two plotting methods, "plot_bunchdata" and "plot_phasespacedata", have been added to Beam class.
"getplot" method in Bunch class has been modified and renamed to "long_phasespace".
parent 89fbf885
No related branches found
No related tags found
No related merge requests found
...@@ -10,9 +10,10 @@ Module where particles, bunches and beams are described as objects. ...@@ -10,9 +10,10 @@ Module where particles, bunches and beams are described as objects.
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from stage.mbtrack2.tracking.parallel import Mpi from tracking.parallel import Mpi
from scipy.constants import c, m_e, m_p, e from scipy.constants import c, m_e, m_p, e
import seaborn as sns import seaborn as sns
import h5py as hp
class Particle: class Particle:
""" Define a particle object """ Define a particle object
...@@ -273,25 +274,32 @@ class Bunch: ...@@ -273,25 +274,32 @@ class Bunch:
self.particles["tau"] = values[:,4] self.particles["tau"] = values[:,4]
self.particles["delta"] = values[:,5] self.particles["delta"] = values[:,5]
def getplot(self,x,y,p_type): def long_phasespace(self,x_var="tau",y_var="delta",plot_type="j"):
""" """
Plot the parameters from particles object. Plot longitudinal phase space.
Parameters Parameters
---------- ----------
x: str, name from Bunch object to plot on hor. axis. x_var : str
y: str, name from Bunch object to plot on ver. axis. name from Bunch object to plot on horizontal axis.
p_type: str, "sc" for a scatter plot or "j" for a joint plot. y_var : str
name from Bunch object to plot on ver. axis.
plot_type : str {"j" , "sc"}
type of the plot. The defualt value is "j" for a joint plot.
Can be modified to "sc" for a scatter plot.
""" """
x = self.particles[x]
y = self.particles[y]
if p_type == "sc": if plot_type == "sc":
plt.scatter(x,y) plt.scatter(self.particles["tau"]*1e12,
self.particles["delta"])
plt.xlabel("$\\tau$ (ps)")
plt.ylabel("$\\delta$")
elif p_type == "j": else:
sns.jointplot(x,y,kind="kde") sns.jointplot(self.particles["tau"]*1e12,
self.particles["delta"],kind="kde")
plt.xlabel("$\\tau$ (ps)")
plt.ylabel("$\\delta$")
class Beam: class Beam:
...@@ -528,7 +536,164 @@ class Beam: ...@@ -528,7 +536,164 @@ class Beam:
self.mpi_switch = False self.mpi_switch = False
self.mpi = None self.mpi = None
def long_phasespace(self,bunch_number,x_var="tau",y_var="delta",
plot_type="j"):
"""
Plot longitudinal phase space.
Parameters
----------
bunch_number : int
Specify a bunch among those in beam object to be displayed.
The value must not exceed the total length of filling_pattern object.
x_var : str
name from Bunch object to plot on horizontal axis.
y_var : str
name from Bunch object to plot on ver. axis.
plot_type : str {"j" , "sc"}
type of the plot. The defualt value is "j" for a joint plot.
Can be modified to "sc" for a scatter plot.
"""
if plot_type == "sc":
plt.scatter(self[bunch_number-1]["tau"]*1e12,
self[bunch_number-1]["delta"])
plt.xlabel("$\\tau$ (ps)")
plt.ylabel("$\\delta$")
else:
sns.jointplot(self[bunch_number-1]["tau"]*1e12,
self[bunch_number-1]["delta"],kind="kde")
plt.xlabel("$\\tau$ (ps)")
plt.ylabel("$\\delta$")
def plot_bunchdata(self, filename, detaset, x_var="time"):
"""
Plot the evolution of the variables from the Beam object.
Parameters
----------
filename : str
Name of the HDF5 file that contains the data.
detaset : str {"current","emit","mean","std"}
HDF5 file's dataset to be plotted.
x_var : str, optional
The variable to be plotted on horizontal axis. The default is "time".
"""
file = hp.File(filename, "r")
group = "BunchData_0" # Data group of the HDF5 file
if detaset == "current":
y_var = file[group][detaset][:]*1e3
label = "current (mA)"
elif detaset == "emit":
axis = int(input("""Specify the axis by entering the corresponding number :
horizontal axis (x) -> 0
vertical axis (y) -> 1
longitudinal axis (s) -> 2
: """))
y_var = file[group][detaset][axis]*1e9
if axis == 0: label = "hor. emittance (nm.rad)"
elif axis == 1: label = "ver. emittance (nm.rad)"
elif axis == 2: label = "long. emittance (nm.rad)"
elif detaset == "mean" or "std":
param = int(input("""Specify the variable by entering the corresponding number: \
x -> 0 \
xp -> 1 \
y -> 2 \
yp -> 3 \
tau -> 4 \
delta -> 5 \
:"""))
if param == 0:
y_var = file[group][detaset][param]*1e6
label = "x (um)"
elif param == 1:
y_var = file[group][detaset][param]*1e6
label = "x' ($\\mu$rad)"
elif param == 2:
y_var = file[group][detaset][param]*1e6
label = "y (um)"
elif param == 3:
y_var = file[group][detaset][param]*1e6
label = "y' ($\\mu$rad)"
elif param == 4:
y_var = file[group][detaset][param]*1e12
label = "$\\tau$ (ps)"
else :
y_var = file[group][detaset][param]
label = "$\\delta$"
plt.plot(file[group]["time"][:],y_var)
plt.xlabel("number of turns")
plt.ylabel(label)
def plot_phasespacedata(self, filename, total_size, save_every, dataset):
"""
Plot data from PhaseSpaceData_0 group of the HDF5 file.
Parameters
----------
filename : str
Name of the HDF5 file that contains the data.
total_size : int
Total size of the save regarding to 'PhaseSpaceMonitor' object.
save_every : int
The frequency of the save regarding to 'PhaseSpaceMonitor' object.
dataset : str {'alive', 'partcicles'}
HDF5 file's dataset to be plotted.
"""
file = hp.File(filename, "r")
data_points = int(total_size/save_every)
timelist = file["PhaseSpaceData_0"]["time"][0:data_points]
if dataset == "alive":
alive_at_a_time = []
for i in range (data_points):
alive_at_a_time.append(np.sum(file["PhaseSpaceData_0"][dataset][:,i]))
plt.plot(timelist,alive_at_a_time)
plt.xlabel("number of turns")
plt.ylabel("number of alive particles")
elif dataset == "particles":
print("""Specify the parameter to be plotted by entering the corresponding number:
x -> 0
xp -> 1
y -> 2
yp -> 3
tau -> 4
delta -> 5 """)
x_var = int(input("Horizontal axis: "))
y_var = int(input("Vertical axis: "))
print("Specify the time at turn =", timelist)
turn = int(input(": "))
index_in_timelist = np.where(timelist==turn)
sns.jointplot(file["PhaseSpaceData_0"][dataset][:,x_var,index_in_timelist],
file["PhaseSpaceData_0"][dataset][:,y_var,index_in_timelist],
kind="kde")
name_list = ["x (m)","xp (rad)","y (m)","yp (rad)","$\\tau$ (s)","$\\delta$"]
plt.xlabel(name_list[x_var])
plt.ylabel(name_list[y_var])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment