From c7f2a92b37eb672bd54b6fae4540f412aa944b05 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Fri, 22 Apr 2022 17:33:05 +0200
Subject: [PATCH] Add Courant Snyder invariant to BeamMonitor

Add Courant Snyder invariant monitoring in BeamMonitor
Add a bunch_cs method to Beam
---
 mbtrack2/tracking/monitors/monitors.py | 16 ++++++++++++++--
 mbtrack2/tracking/particles.py         | 10 ++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/mbtrack2/tracking/monitors/monitors.py b/mbtrack2/tracking/monitors/monitors.py
index 6d01794..e494497 100644
--- a/mbtrack2/tracking/monitors/monitors.py
+++ b/mbtrack2/tracking/monitors/monitors.py
@@ -441,11 +441,13 @@ class BeamMonitor(Monitor):
         dict_buffer = {"mean" : (6, h, buffer_size), 
                        "std" : (6, h, buffer_size),
                        "emit" : (3, h, buffer_size),
-                       "current" : (h, buffer_size)}
+                       "current" : (h, buffer_size),
+                       "cs_invariant" : (2, h, buffer_size)}
         dict_file = {"mean" : (6, h, total_size), 
                        "std" : (6, h, total_size),
                        "emit" : (3, h, total_size),
-                       "current" : (h, total_size)}
+                       "current" : (h, total_size),
+                       "cs_invariant" : (2, h, total_size)}
         
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
@@ -481,6 +483,7 @@ class BeamMonitor(Monitor):
         self.std[:, bunch_num, self.buffer_count] = bunch.std
         self.emit[:, bunch_num, self.buffer_count] = bunch.emit
         self.current[bunch_num, self.buffer_count] = bunch.current
+        self.cs_invariant[:, bunch_num, self.buffer_count] = bunch.cs_invariant
         
         self.buffer_count += 1
         
@@ -502,6 +505,7 @@ class BeamMonitor(Monitor):
         self.std[:, :, self.buffer_count] = beam.bunch_std
         self.emit[:, :, self.buffer_count] = beam.bunch_emit
         self.current[:, self.buffer_count] = beam.bunch_current
+        self.cs_invariant[:, :, self.buffer_count] = beam.bunch_cs
         
         self.buffer_count += 1
         
@@ -535,6 +539,10 @@ class BeamMonitor(Monitor):
         self.file[self.group_name]["current"][bunch_num, 
                  self.write_count*self.buffer_size:(self.write_count+1) * 
                  self.buffer_size] = self.current[bunch_num, :]
+        
+        self.file[self.group_name]["cs_invariant"][:, bunch_num, 
+                 self.write_count*self.buffer_size:(self.write_count+1) * 
+                 self.buffer_size] = self.cs_invariant[:, bunch_num, :]
                  
         self.file.flush() 
         self.write_count += 1
@@ -562,6 +570,10 @@ class BeamMonitor(Monitor):
         self.file[self.group_name]["current"][:, 
                  self.write_count*self.buffer_size:(self.write_count+1) * 
                  self.buffer_size] = self.current
+        
+        self.file[self.group_name]["cs_invariant"][:, :, 
+                 self.write_count*self.buffer_size:(self.write_count+1) * 
+                 self.buffer_size] = self.cs_invariant
                  
         self.file.flush() 
         self.write_count += 1
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index b86b409..49b9adf 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -679,6 +679,16 @@ class Beam:
             bunch_emit[:,index] = bunch.emit
         return bunch_emit
     
+    @property
+    def bunch_cs(self):
+        """Return an array with the average Courant-Snyder invariant for each 
+        bunch"""
+        bunch_cs = np.zeros((2,self.ring.h))
+        for idx, bunch in enumerate(self.not_empty):
+            index = self.bunch_index[idx]
+            bunch_cs[:,index] = bunch.cs_invariant
+        return bunch_cs
+    
     def mpi_init(self):
         """Switch on MPI parallelisation and initialise a Mpi object"""
         from mbtrack2.tracking.parallel import Mpi
-- 
GitLab