From c847d125232282cce8afd620732e754f8205d974 Mon Sep 17 00:00:00 2001 From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr> Date: Wed, 5 Jun 2024 17:51:53 +0200 Subject: [PATCH] [Fix] Issue with random sampling in PhaseSpaceMonitor Fix plotting in plot_phasespacedata --- mbtrack2/tracking/monitors/monitors.py | 46 ++++++++++++++++---------- mbtrack2/tracking/monitors/plotting.py | 8 ++--- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/mbtrack2/tracking/monitors/monitors.py b/mbtrack2/tracking/monitors/monitors.py index dfa9f54..4e87bf2 100644 --- a/mbtrack2/tracking/monitors/monitors.py +++ b/mbtrack2/tracking/monitors/monitors.py @@ -352,9 +352,10 @@ class PhaseSpaceMonitor(Monitor): ---------- bunch_number : int Bunch to monitor - mp_number : int or float + sample_number : int or float Number of macroparticle in the phase space to save. If less than the - total number of macroparticles, a random fraction of the bunch is saved. + total number of macroparticles, the total number must be specified in + optinal parameter mp_number and a random fraction of the bunch is saved. save_every : int or float Set the frequency of the save. The data is saved every save_every call of the montior. @@ -373,6 +374,10 @@ class PhaseSpaceMonitor(Monitor): If True, open the HDF5 file in parallel mode, which is needed to allow several cores to write in the same file at the same time. If False, open the HDF5 file in standard mode. + mp_number : int or float, optional + Total number of macroparticle of the tracked bunch. + Mandatory if sample_number != mp_number. + Default is None. Methods ------- @@ -382,29 +387,41 @@ class PhaseSpaceMonitor(Monitor): def __init__(self, bunch_number, - mp_number, + sample_number, save_every, buffer_size, total_size, file_name=None, - mpi_mode=False): + mpi_mode=False, + mp_number=None): self.bunch_number = bunch_number - self.mp_number = int(mp_number) + self.sample_number = int(sample_number) + if mp_number is None: + self.mp_number = self.sample_number + else: + self.mp_number = int(mp_number) group_name = "PhaseSpaceData_" + str(self.bunch_number) dict_buffer = { - "particles": (self.mp_number, 6, buffer_size), - "alive": (self.mp_number, buffer_size) + "particles": (self.sample_number, 6, buffer_size), + "alive": (self.sample_number, buffer_size) } dict_file = { - "particles": (self.mp_number, 6, total_size), - "alive": (self.mp_number, total_size) + "particles": (self.sample_number, 6, total_size), + "alive": (self.sample_number, total_size) } self.monitor_init(group_name, save_every, buffer_size, total_size, dict_buffer, dict_file, file_name, mpi_mode) self.dict_buffer = dict_buffer self.dict_file = dict_file + + if self.sample_number != self.mp_number: + index = np.arange(self.mp_number) + samples_meta = random.sample(list(index), self.sample_number) + self.samples = sorted(samples_meta) + else: + self.samples = slice(None) def track(self, object_to_save): """ @@ -426,17 +443,10 @@ class PhaseSpaceMonitor(Monitor): """ self.time[self.buffer_count] = self.track_count - if len(bunch.alive) != self.mp_number: - index = np.arange(len(bunch.alive)) - samples_meta = random.sample(list(index), self.mp_number) - samples = sorted(samples_meta) - else: - samples = slice(None) - - self.alive[:, self.buffer_count] = bunch.alive[samples] + self.alive[:, self.buffer_count] = bunch.alive[self.samples] for i, dim in enumerate(bunch): self.particles[:, i, - self.buffer_count] = bunch.particles[dim][samples] + self.buffer_count] = bunch.particles[dim][self.samples] self.buffer_count += 1 diff --git a/mbtrack2/tracking/monitors/plotting.py b/mbtrack2/tracking/monitors/plotting.py index a02842b..9e1f56b 100644 --- a/mbtrack2/tracking/monitors/plotting.py +++ b/mbtrack2/tracking/monitors/plotting.py @@ -401,7 +401,7 @@ def plot_phasespacedata(filename, turn, only_alive=True, plot_size=1, - plot_kind='kde'): + plot_kind='scatter'): """ Plot data recorded by PhaseSpaceMonitor. @@ -425,7 +425,7 @@ def plot_phasespacedata(filename, of macro-particles recorded. This option helps reduce processing time when the data is big. plot_kind : {'scatter', 'kde', 'hex', 'reg', 'resid'}, optional - The plot style. The default is 'kde'. + The plot style. The default is 'scatter'. Return ------ @@ -473,8 +473,8 @@ def plot_phasespacedata(filename, x_axis = path[samples, option_dict[x_var], turn_index[0][0]] y_axis = path[samples, option_dict[y_var], turn_index[0][0]] - fig = sns.jointplot(x_axis * scale[option_dict[x_var]], - y_axis * scale[option_dict[y_var]], + fig = sns.jointplot(x=x_axis * scale[option_dict[x_var]], + y=y_axis * scale[option_dict[y_var]], kind=plot_kind) plt.xlabel(label[option_dict[x_var]]) -- GitLab