From 8010a120c71289660df450659628fc4100b14be3 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Sun, 15 Nov 2020 10:53:18 +0100
Subject: [PATCH] Bugfix : PhaseSpaceMonitor

Bugfix due to switch to numpy arrays.
Allow to save less mp than the full number in the bunch.

Bugfix2 : PhaseSpaceMonitor
---
 tracking/monitors/monitors.py | 31 ++++++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/tracking/monitors/monitors.py b/tracking/monitors/monitors.py
index c82a2f4..cf139f7 100644
--- a/tracking/monitors/monitors.py
+++ b/tracking/monitors/monitors.py
@@ -295,7 +295,8 @@ class PhaseSpaceMonitor(Monitor):
     bunch_number : int
         Bunch to monitor
     mp_number : int or float
-        Number of macroparticle in the phase space to save.
+        Number of macroparticle in the phase space to save. If less than the 
+        total number of macroparticles, a random fraction of the bunch is saved.
     file_name : string, optional
         Name of the HDF5 where the data will be stored. Must be specified
         the first time a subclass of Monitor is instancied and must be None
@@ -346,6 +347,34 @@ class PhaseSpaceMonitor(Monitor):
         object_to_save : Bunch or Beam object
         """        
         self.track_bunch_data(object_to_save)
+        
+    def to_buffer(self, bunch):
+        """
+        Save data to buffer.
+        
+        Parameters
+        ----------
+        bunch : Bunch object
+        """
+        self.time[self.buffer_count] = self.track_count
+        
+        if len(bunch.alive) != self.mp_number:
+            index = np.arange(len(bunch.alive))
+            samples_meta = random.sample(list(index), self.mp_number)
+            samples = sorted(samples_meta)
+        else:
+            samples = slice(None)
+
+        self.alive[:, self.buffer_count] = bunch.alive[samples]
+        for i, dim in enumerate(bunch):
+            self.particles[:, i, self.buffer_count] = bunch.particles[dim][samples]
+        
+        self.buffer_count += 1
+        
+        if self.buffer_count == self.buffer_size:
+            self.write()
+            self.buffer_count = 0
+        
 
             
 class BeamMonitor(Monitor):
-- 
GitLab