From b54a0baeeafe6f41865b573af5f5891f6c8ad6f2 Mon Sep 17 00:00:00 2001
From: Keon Hee KIM <alphaover2pi@korea.ac.kr>
Date: Wed, 31 Jul 2024 15:13:08 +0200
Subject: [PATCH] Optimize boolean indexing in CircularResistiveWall

---
 mbtrack2/impedance/resistive_wall.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/mbtrack2/impedance/resistive_wall.py b/mbtrack2/impedance/resistive_wall.py
index 91fa7e1..3d63d32 100644
--- a/mbtrack2/impedance/resistive_wall.py
+++ b/mbtrack2/impedance/resistive_wall.py
@@ -163,18 +163,17 @@ class CircularResistiveWall(WakeField):
         in high-energy particle accelerators. World Scientific.
         """
         wl = np.zeros_like(time)
-        idx1 = time < 0
         if exact == True:
+            idx1 = time > 0
             idx2 = time == 0
-            idx3 = np.logical_not(np.logical_or(idx1, idx2))
             factor = (self.Z0 * c / (3 * np.pi * self.radius**2) * self.length)
-            if np.any(idx2):
-                # fundamental theorem of beam loading
-                wl[idx2] = 3 * factor / 2
-            wl[idx3] = self.__LongWakeExact(time[idx3], factor)
+            wl[idx1] = self.__LongWakeExact(time[idx1], factor)
+
+            # fundamental theorem of beam loading
+            wl[idx2] = 3 * factor / 2
         else:
-            idx2 = np.logical_not(idx1)
-            wl[idx2] = self.__LongWakeApprox(time[idx2])
+            idx = time >= 0
+            wl[idx] = self.__LongWakeApprox(time[idx])
         return wl
 
     def TransverseWakeFunction(self, time, exact=True):
@@ -220,16 +219,14 @@ class CircularResistiveWall(WakeField):
         Detectors and Associated Equipment 806 (2016): 221-230.
         """
         wt = np.zeros_like(time)
-        idx1 = time < 0
         if exact == True:
-            idx2 = time == 0
-            idx3 = np.logical_not(np.logical_or(idx1, idx2))
+            idx = time > 0
             factor = ((self.Z0 * c**2 * self.t0) /
                       (3 * np.pi * self.radius**4) * self.length)
-            wt[idx3] = self.__TransWakeExact(time[idx3], factor)
+            wt[idx] = self.__TransWakeExact(time[idx], factor)
         else:
-            idx2 = np.logical_not(idx1)
-            wt[idx2] = self.__TransWakeApprox(time[idx2])
+            idx = time >= 0
+            wt[idx] = self.__TransWakeApprox(time[idx])
         return wt
 
     def __LongWakeExact(self, t, factor):
-- 
GitLab