From 190c19575e1597f2fc4658923a33d3f4d23235ec Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Fri, 16 Sep 2022 14:22:26 +0200
Subject: [PATCH] Performance improvement for WakePotential

Switch from interp1d to np.interp
---
 mbtrack2/tracking/wakepotential.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py
index 953b49a..fcd0d50 100644
--- a/mbtrack2/tracking/wakepotential.py
+++ b/mbtrack2/tracking/wakepotential.py
@@ -162,9 +162,7 @@ class WakePotential(Element):
             dipole = np.insert(dipole, 0, np.zeros(N+1))
             
         # Interpole on tau0 to get the same size as W0
-        interp_dipole = interp1d(self.tau, dipole, fill_value=0, 
-                             bounds_error=False)
-        dipole0 = interp_dipole(tau0)
+        dipole0 = np.interp(tau0, self.tau, dipole, 0, 0)
             
         setattr(self, "dipole_" + plane, dipole0)
         return dipole0
@@ -270,10 +268,7 @@ class WakePotential(Element):
 
         (tau0, dtau0, W0) = self.prepare_wakefunction(wake_type, self.tau)
         
-        interp_profile = interp1d(self.tau, self.rho, fill_value=0, 
-                                     bounds_error=False)
-        
-        profile0 = interp_profile(tau0)
+        profile0 = np.interp(tau0, self.tau, self.rho, 0, 0)
         
         if wake_type == "Wlong" or wake_type == "Wxquad" or wake_type == "Wyquad":
             Wp = signal.convolve(profile0, W0*-1, mode='same')*dtau0
@@ -307,8 +302,7 @@ class WakePotential(Element):
             self.charge_density(bunch)
             for wake_type in self.types:
                 tau0, Wp = self.get_wakepotential(bunch, wake_type)
-                f = interp1d(tau0 + self.tau_mean, Wp, fill_value = 0, bounds_error = False)
-                Wp_interp = f(bunch["tau"])
+                Wp_interp = np.interp(bunch["tau"], tau0 + self.tau_mean, Wp, 0, 0)
                 if wake_type == "Wlong":
                     bunch["delta"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wxdip":
-- 
GitLab