From 060ec29e64dcbeecabcf267f3814bd896c615ede Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Tue, 30 Jan 2024 17:17:51 +0100
Subject: [PATCH] Add interp_on_postion option for WakePotential class

interp_on_postion=False allows for faster tracking but the wake potential is interpolated on the bin center and each particle of the bin get the same value.
---
 mbtrack2/tracking/wakepotential.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py
index d64865c..46abfbd 100644
--- a/mbtrack2/tracking/wakepotential.py
+++ b/mbtrack2/tracking/wakepotential.py
@@ -33,6 +33,11 @@ class WakePotential(Element):
         functions must be uniformly sampled!
     n_bin : int, optional
         Number of bins for constructing the longitudinal bunch profile.
+    interp_on_postion : bool, optional
+        If True, the computed wake potential is interpolated on the exact 
+        particle location. If False, the wake potential is interpolated on the 
+        bin center and each particle of the bin get the same value.
+        Default is True.
         
     Attributes
     ----------
@@ -71,13 +76,14 @@ class WakePotential(Element):
         Reduce wake function samping by an integer factor.
         
     """
-    def __init__(self, ring, wakefield, n_bin=80):
+    def __init__(self, ring, wakefield, n_bin=80, interp_on_postion=True):
         self.wakefield = wakefield
         self.types = self.wakefield.wake_components
         self.n_types = len(self.wakefield.wake_components)
         self.ring = ring
         self.n_bin = n_bin
         self.check_sampling()
+        self.interp_on_postion = interp_on_postion
 
         # Suppress numpy warning for floating-point operations.
         np.seterr(invalid='ignore')
@@ -296,8 +302,13 @@ class WakePotential(Element):
             self.charge_density(bunch)
             for wake_type in self.types:
                 tau0, Wp = self.get_wakepotential(bunch, wake_type)
-                Wp_interp = np.interp(bunch["tau"], tau0 + self.tau_mean, Wp,
-                                      0, 0)
+                if self.interp_on_postion:
+                    Wp_interp = np.interp(bunch["tau"], tau0 + self.tau_mean, Wp,
+                                          0, 0)
+                else:
+                    Wp_interp = np.interp(self.center, tau0 + self.tau_mean, Wp,
+                                          0, 0)
+                    Wp_interp= Wp_interp[self.sorted_index]
                 if wake_type == "Wlong":
                     bunch["delta"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wxdip":
-- 
GitLab