Skip to content
Snippets Groups Projects

[Fix] Issue with random sampling in PhaseSpaceMonitor

Merged Alexis GAMELIN requested to merge PhaseSpaceMonitor into develop
2 files
+ 32
22
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -352,9 +352,10 @@ class PhaseSpaceMonitor(Monitor):
----------
bunch_number : int
Bunch to monitor
mp_number : int or float
sample_number : int or float
Number of macroparticle in the phase space to save. If less than the
total number of macroparticles, a random fraction of the bunch is saved.
total number of macroparticles, the total number must be specified in
optinal parameter mp_number and a random fraction of the bunch is saved.
save_every : int or float
Set the frequency of the save. The data is saved every save_every
call of the montior.
@@ -373,6 +374,10 @@ class PhaseSpaceMonitor(Monitor):
If True, open the HDF5 file in parallel mode, which is needed to
allow several cores to write in the same file at the same time.
If False, open the HDF5 file in standard mode.
mp_number : int or float, optional
Total number of macroparticle of the tracked bunch.
Mandatory if sample_number != mp_number.
Default is None.
Methods
-------
@@ -382,29 +387,41 @@ class PhaseSpaceMonitor(Monitor):
def __init__(self,
bunch_number,
mp_number,
sample_number,
save_every,
buffer_size,
total_size,
file_name=None,
mpi_mode=False):
mpi_mode=False,
mp_number=None):
self.bunch_number = bunch_number
self.mp_number = int(mp_number)
self.sample_number = int(sample_number)
if mp_number is None:
self.mp_number = self.sample_number
else:
self.mp_number = int(mp_number)
group_name = "PhaseSpaceData_" + str(self.bunch_number)
dict_buffer = {
"particles": (self.mp_number, 6, buffer_size),
"alive": (self.mp_number, buffer_size)
"particles": (self.sample_number, 6, buffer_size),
"alive": (self.sample_number, buffer_size)
}
dict_file = {
"particles": (self.mp_number, 6, total_size),
"alive": (self.mp_number, total_size)
"particles": (self.sample_number, 6, total_size),
"alive": (self.sample_number, total_size)
}
self.monitor_init(group_name, save_every, buffer_size, total_size,
dict_buffer, dict_file, file_name, mpi_mode)
self.dict_buffer = dict_buffer
self.dict_file = dict_file
if self.sample_number != self.mp_number:
index = np.arange(self.mp_number)
samples_meta = random.sample(list(index), self.sample_number)
self.samples = sorted(samples_meta)
else:
self.samples = slice(None)
def track(self, object_to_save):
"""
@@ -426,17 +443,10 @@ class PhaseSpaceMonitor(Monitor):
"""
self.time[self.buffer_count] = self.track_count
if len(bunch.alive) != self.mp_number:
index = np.arange(len(bunch.alive))
samples_meta = random.sample(list(index), self.mp_number)
samples = sorted(samples_meta)
else:
samples = slice(None)
self.alive[:, self.buffer_count] = bunch.alive[samples]
self.alive[:, self.buffer_count] = bunch.alive[self.samples]
for i, dim in enumerate(bunch):
self.particles[:, i,
self.buffer_count] = bunch.particles[dim][samples]
self.buffer_count] = bunch.particles[dim][self.samples]
self.buffer_count += 1
Loading