Skip to content
Snippets Groups Projects
Commit ad4e7890 authored by Alexis GAMELIN's avatar Alexis GAMELIN
Browse files

Merge branch 'PhaseSpaceMonitor' into 'develop'

[Fix] Issue with random sampling in PhaseSpaceMonitor

See merge request !4
parents 033c530d c847d125
No related branches found
No related tags found
2 merge requests!13v0.7.0,!4[Fix] Issue with random sampling in PhaseSpaceMonitor
......@@ -352,9 +352,10 @@ class PhaseSpaceMonitor(Monitor):
----------
bunch_number : int
Bunch to monitor
mp_number : int or float
sample_number : int or float
Number of macroparticle in the phase space to save. If less than the
total number of macroparticles, a random fraction of the bunch is saved.
total number of macroparticles, the total number must be specified in
optinal parameter mp_number and a random fraction of the bunch is saved.
save_every : int or float
Set the frequency of the save. The data is saved every save_every
call of the montior.
......@@ -373,6 +374,10 @@ class PhaseSpaceMonitor(Monitor):
If True, open the HDF5 file in parallel mode, which is needed to
allow several cores to write in the same file at the same time.
If False, open the HDF5 file in standard mode.
mp_number : int or float, optional
Total number of macroparticle of the tracked bunch.
Mandatory if sample_number != mp_number.
Default is None.
Methods
-------
......@@ -382,29 +387,41 @@ class PhaseSpaceMonitor(Monitor):
def __init__(self,
bunch_number,
mp_number,
sample_number,
save_every,
buffer_size,
total_size,
file_name=None,
mpi_mode=False):
mpi_mode=False,
mp_number=None):
self.bunch_number = bunch_number
self.mp_number = int(mp_number)
self.sample_number = int(sample_number)
if mp_number is None:
self.mp_number = self.sample_number
else:
self.mp_number = int(mp_number)
group_name = "PhaseSpaceData_" + str(self.bunch_number)
dict_buffer = {
"particles": (self.mp_number, 6, buffer_size),
"alive": (self.mp_number, buffer_size)
"particles": (self.sample_number, 6, buffer_size),
"alive": (self.sample_number, buffer_size)
}
dict_file = {
"particles": (self.mp_number, 6, total_size),
"alive": (self.mp_number, total_size)
"particles": (self.sample_number, 6, total_size),
"alive": (self.sample_number, total_size)
}
self.monitor_init(group_name, save_every, buffer_size, total_size,
dict_buffer, dict_file, file_name, mpi_mode)
self.dict_buffer = dict_buffer
self.dict_file = dict_file
if self.sample_number != self.mp_number:
index = np.arange(self.mp_number)
samples_meta = random.sample(list(index), self.sample_number)
self.samples = sorted(samples_meta)
else:
self.samples = slice(None)
def track(self, object_to_save):
"""
......@@ -426,17 +443,10 @@ class PhaseSpaceMonitor(Monitor):
"""
self.time[self.buffer_count] = self.track_count
if len(bunch.alive) != self.mp_number:
index = np.arange(len(bunch.alive))
samples_meta = random.sample(list(index), self.mp_number)
samples = sorted(samples_meta)
else:
samples = slice(None)
self.alive[:, self.buffer_count] = bunch.alive[samples]
self.alive[:, self.buffer_count] = bunch.alive[self.samples]
for i, dim in enumerate(bunch):
self.particles[:, i,
self.buffer_count] = bunch.particles[dim][samples]
self.buffer_count] = bunch.particles[dim][self.samples]
self.buffer_count += 1
......
......@@ -401,7 +401,7 @@ def plot_phasespacedata(filename,
turn,
only_alive=True,
plot_size=1,
plot_kind='kde'):
plot_kind='scatter'):
"""
Plot data recorded by PhaseSpaceMonitor.
......@@ -425,7 +425,7 @@ def plot_phasespacedata(filename,
of macro-particles recorded. This option helps reduce processing time
when the data is big.
plot_kind : {'scatter', 'kde', 'hex', 'reg', 'resid'}, optional
The plot style. The default is 'kde'.
The plot style. The default is 'scatter'.
Return
------
......@@ -473,8 +473,8 @@ def plot_phasespacedata(filename,
x_axis = path[samples, option_dict[x_var], turn_index[0][0]]
y_axis = path[samples, option_dict[y_var], turn_index[0][0]]
fig = sns.jointplot(x_axis * scale[option_dict[x_var]],
y_axis * scale[option_dict[y_var]],
fig = sns.jointplot(x=x_axis * scale[option_dict[x_var]],
y=y_axis * scale[option_dict[y_var]],
kind=plot_kind)
plt.xlabel(label[option_dict[x_var]])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment