From 9399d5e26030f0ab910314cf1ac39f9a85b60950 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 17 Nov 2021 18:34:14 +0100
Subject: [PATCH] Add methods to check and modify sampling for WakePotential
 class

---
 tracking/wakepotential.py | 36 ++++++++++++++++++++++++++++++++++++
 1 file changed, 36 insertions(+)

diff --git a/tracking/wakepotential.py b/tracking/wakepotential.py
index b22b173..3adee16 100644
--- a/tracking/wakepotential.py
+++ b/tracking/wakepotential.py
@@ -65,6 +65,10 @@ class WakePotential(Element):
         Calculate the loss factor and kick factor from the wake potential and 
         compare it to a reference value assuming a Gaussian bunch computed in 
         the frequency domain.
+    check_sampling()
+        Check if the wake function sampling is uniform.
+    reduce_sampling(factor)
+        Reduce wake function samping by an integer factor.
         
     """
     
@@ -74,6 +78,7 @@ class WakePotential(Element):
         self.n_types = len(self.wakefield.wake_components)
         self.ring = ring
         self.n_bin = n_bin
+        self.check_sampling()
             
     def charge_density(self, bunch):
         """
@@ -527,7 +532,38 @@ class WakePotential(Element):
                                  columns=column, 
                                  index=index)
         return loss_data
+
+    def check_sampling(self):
+        """
+        Check if the wake function sampling is uniform.
+
+        Raises
+        ------
+        ValueError
+
+        """
+        for wake_type in self.types:
+            idx = getattr(self.wakefield, wake_type).data.index
+            diff = idx[1:]-idx[:-1]
+            result = np.all(np.isclose(diff, diff[0], atol=1e-15))
+            if result is False:
+                raise ValueError("The wake function must be uniformly sampled.")
     
+    def reduce_sampling(self, factor):
+        """
+        Reduce wake function samping by an integer factor.
+        
+        Used to reduce computation time for long bunches.
+
+        Parameters
+        ----------
+        factor : int
+
+        """
+        for wake_type in self.types:
+            idx = getattr(self.wakefield, wake_type).data.index[::factor]
+            getattr(self.wakefield, wake_type).data = getattr(self.wakefield, wake_type).data.loc[idx]
+        self.check_sampling()
     
     
     
-- 
GitLab