From 5fe518a5e19a169f0e5b8d469afa4c2f3853266c Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 17 Jul 2024 18:02:14 +0200
Subject: [PATCH] Relax test values and formatting

---
 mbtrack2/utilities/misc.py | 116 +++++++++++++++++++++----------------
 tests/test_bunch.py        |  49 ++++++++--------
 2 files changed, 89 insertions(+), 76 deletions(-)

diff --git a/mbtrack2/utilities/misc.py b/mbtrack2/utilities/misc.py
index 6aaabef..f884b17 100644
--- a/mbtrack2/utilities/misc.py
+++ b/mbtrack2/utilities/misc.py
@@ -8,7 +8,7 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 from scipy.interpolate import interp1d
-from scipy.special import gamma, factorial, hyp2f1
+from scipy.special import factorial, gamma, hyp2f1
 
 from mbtrack2.impedance.wakefield import Impedance
 from mbtrack2.utilities.spectrum import spectral_density
@@ -199,23 +199,23 @@ def yokoya_elliptic(x_radius, y_radius):
     of beam pipes of general cross section." Phys. Rev. E 47, 656 (1993).
 
     """
-    
+
     if (x_radius <= 0) or (y_radius <= 0):
         raise ValueError("Both radii must be non-zero positive values.")
     elif (x_radius == np.inf) and (y_radius == np.inf):
         raise ValueError("Both radii have infinite values.")
     elif x_radius == np.inf:
         yoklong = 1.0
-        yokxdip = np.pi**2/24
-        yokydip = np.pi**2/12
-        yokxquad = -np.pi**2/24
-        yokyquad = np.pi**2/24
+        yokxdip = np.pi**2 / 24
+        yokydip = np.pi**2 / 12
+        yokxquad = -np.pi**2 / 24
+        yokyquad = np.pi**2 / 24
     elif y_radius == np.inf:
         yoklong = 1.0
-        yokxdip = np.pi**2/12
-        yokydip = np.pi**2/24
-        yokxquad = np.pi**2/24
-        yokyquad = -np.pi**2/24
+        yokxdip = np.pi**2 / 12
+        yokydip = np.pi**2 / 24
+        yokxquad = np.pi**2 / 24
+        yokyquad = -np.pi**2 / 24
     else:
         if y_radius < x_radius:
             small_semiaxis = y_radius
@@ -224,7 +224,7 @@ def yokoya_elliptic(x_radius, y_radius):
             small_semiaxis = x_radius
             large_semiaxis = y_radius
 
-        qr = (large_semiaxis - small_semiaxis) / (large_semiaxis + small_semiaxis)
+        qr = (large_semiaxis-small_semiaxis) / (large_semiaxis+small_semiaxis)
 
         if qr == 0:
             yoklong = 1.0
@@ -236,43 +236,53 @@ def yokoya_elliptic(x_radius, y_radius):
         else:
             # Define form factor functions
             def function_ff(small_semiaxis, F, mu_b, ip, il):
-                coeff_fflong = 2*np.sqrt(2)*small_semiaxis/(np.pi*F)
-                coeff_fftrans = np.sqrt(2)*small_semiaxis**3/(np.pi*F**3)
-
-                fflong = ( coeff_fflong*(-1)**ip/np.cosh(2*ip*mu_b)
-                        * (-1)**il/np.cosh(2*il*mu_b) )
-                ffdx = ( coeff_fftrans*(-1)**ip*(2*ip+1)/np.cosh((2*ip+1)*mu_b)
-                        * (-1)**il*(2*il+1)/np.cosh((2*il+1)*mu_b) )
-                ffdy = ( coeff_fftrans*(-1)**ip*(2*ip+1)/np.sinh((2*ip+1)*mu_b)
-                        * (-1)**il*(2*il+1)/np.sinh((2*il+1)*mu_b) )
-                ffquad = ( coeff_fftrans*(-1)**ip*(2*ip)**2/np.cosh(2*ip*mu_b)
-                        * (-1)**il/np.cosh(2*il*mu_b) )
-            
+                coeff_fflong = 2 * np.sqrt(2) * small_semiaxis / (np.pi * F)
+                coeff_fftrans = np.sqrt(2) * small_semiaxis**3 / (np.pi * F**3)
+
+                fflong = (coeff_fflong * (-1)**ip / np.cosh(2 * ip * mu_b) *
+                          (-1)**il / np.cosh(2 * il * mu_b))
+                ffdx = (coeff_fftrans * (-1)**ip * (2*ip + 1) / np.cosh(
+                    (2*ip + 1) * mu_b) * (-1)**il * (2*il + 1) / np.cosh(
+                        (2*il + 1) * mu_b))
+                ffdy = (coeff_fftrans * (-1)**ip * (2*ip + 1) / np.sinh(
+                    (2*ip + 1) * mu_b) * (-1)**il * (2*il + 1) / np.sinh(
+                        (2*il + 1) * mu_b))
+                ffquad = (coeff_fftrans * (-1)**ip * (2 * ip)**2 /
+                          np.cosh(2 * ip * mu_b) * (-1)**il /
+                          np.cosh(2 * il * mu_b))
+
                 return (fflong, ffdx, ffdy, ffquad)
 
             def function_L(mu_b, ip, il):
-                common_L = ( np.sqrt(2)*np.pi*np.exp(-(2*abs(ip-il)+1)*mu_b)
-                        * gamma(0.5+abs(ip-il))/(gamma(0.5)*factorial(abs(ip-il)))
-                        * hyp2f1(0.5, abs(ip-il)+0.5, abs(ip-il)+1, np.exp(-4*mu_b)) )       
-                val_m = ( np.sqrt(2)*np.pi*np.exp(-(2*ip+2*il+1)*mu_b)
-                        * gamma(0.5+ip+il)/(gamma(0.5)*factorial(ip+il))
-                        * hyp2f1(0.5, ip+il+0.5, ip+il+1, np.exp(-4*mu_b)) )
-                val_d = ( np.sqrt(2)*np.pi*np.exp(-(2*ip+2*il+3)*mu_b)
-                        * gamma(1.5+ip+il)/(gamma(0.5)*factorial(ip+il+1))
-                        * hyp2f1(0.5, ip+il+1.5, ip+il+2, np.exp(-4*mu_b)) )
-            
+                common_L = (np.sqrt(2) * np.pi *
+                            np.exp(-(2 * abs(ip - il) + 1) * mu_b) *
+                            gamma(0.5 + abs(ip - il)) /
+                            (gamma(0.5) * factorial(abs(ip - il))) *
+                            hyp2f1(0.5,
+                                   abs(ip - il) + 0.5,
+                                   abs(ip - il) + 1, np.exp(-4 * mu_b)))
+                val_m = (
+                    np.sqrt(2) * np.pi * np.exp(-(2*ip + 2*il + 1) * mu_b) *
+                    gamma(0.5 + ip + il) / (gamma(0.5) * factorial(ip + il)) *
+                    hyp2f1(0.5, ip + il + 0.5, ip + il + 1, np.exp(-4 * mu_b)))
+                val_d = (
+                    np.sqrt(2) * np.pi * np.exp(-(2*ip + 2*il + 3) * mu_b) *
+                    gamma(1.5 + ip + il) /
+                    (gamma(0.5) * factorial(ip + il + 1)) *
+                    hyp2f1(0.5, ip + il + 1.5, ip + il + 2, np.exp(-4 * mu_b)))
+
                 Lm = common_L + val_m
                 Ldx = common_L + val_d
                 Ldy = common_L - val_d
 
                 return (Lm, Ldx, Ldy)
-            
+
             ip_range = np.arange(51)
             il_range = np.arange(51)
             ip, il = np.meshgrid(ip_range, il_range, indexing="ij")
 
-            coeff_long = np.where( (ip == 0) & (il == 0), 0.25,
-                                    np.where((ip == 0) | (il == 0), 0.5, 1.0) )
+            coeff_long = np.where((ip == 0) & (il == 0), 0.25,
+                                  np.where((ip == 0) | (il == 0), 0.5, 1.0))
             coeff_quad = np.where(il == 0, 0.5, 1.0)
 
             # Our equations are approximately valid only for qr (ratio) values
@@ -281,11 +291,12 @@ def yokoya_elliptic(x_radius, y_radius):
 
             if qr < qr_th:
                 F = np.sqrt(large_semiaxis**2 - small_semiaxis**2)
-                mu_b = np.arccosh(large_semiaxis/F)
+                mu_b = np.arccosh(large_semiaxis / F)
 
-                ff_values = np.array(function_ff(small_semiaxis, F, mu_b, ip, il))
+                ff_values = np.array(
+                    function_ff(small_semiaxis, F, mu_b, ip, il))
                 L_values = np.array(function_L(mu_b, ip, il))
-            
+
                 yoklong = np.sum(coeff_long * ff_values[0] * L_values[0])
                 yokxdip = np.sum(ff_values[1] * L_values[1])
                 yokydip = np.sum(ff_values[2] * L_values[2])
@@ -295,31 +306,34 @@ def yokoya_elliptic(x_radius, y_radius):
                 if y_radius > x_radius:
                     yokxdip, yokydip = yokydip, yokxdip
                     yokxquad, yokyquad = yokyquad, yokxquad
-            
+
             # Beyond the threshold (qr >= qr_th), they may not be valid,
             # but should converge to Yokoya form factors of parallel plates.
             # Fortunately, beyond this threshold, they should show asymptotic behavior,
             # so we perform linear interpolation.
             else:
                 yoklong_pp = 1.0
-                yokxdip_pp = np.pi**2/24
-                yokydip_pp = np.pi**2/12
-                yokxquad_pp = -np.pi**2/24
-                yokyquad_pp = np.pi**2/24
+                yokxdip_pp = np.pi**2 / 24
+                yokydip_pp = np.pi**2 / 12
+                yokxquad_pp = -np.pi**2 / 24
+                yokyquad_pp = np.pi**2 / 24
 
-                small_semiaxis_th = large_semiaxis*(1-qr_th)/(1+qr_th)
+                small_semiaxis_th = large_semiaxis * (1-qr_th) / (1+qr_th)
                 F_th = np.sqrt(large_semiaxis**2 - small_semiaxis_th**2)
-                mu_b_th = np.arccosh(large_semiaxis/F_th)
+                mu_b_th = np.arccosh(large_semiaxis / F_th)
 
-                ff_values_th = np.array(function_ff(small_semiaxis_th, F_th, mu_b_th, ip, il))
+                ff_values_th = np.array(
+                    function_ff(small_semiaxis_th, F_th, mu_b_th, ip, il))
                 L_values_th = np.array(function_L(mu_b_th, ip, il))
 
-                yoklong_th = np.sum(coeff_long * ff_values_th[0] * L_values_th[0])
+                yoklong_th = np.sum(coeff_long * ff_values_th[0] *
+                                    L_values_th[0])
                 yokxdip_th = np.sum(ff_values_th[1] * L_values_th[1])
                 yokydip_th = np.sum(ff_values_th[2] * L_values_th[2])
-                yokxquad_th = -np.sum(coeff_quad * ff_values_th[3] * L_values_th[0])
+                yokxquad_th = -np.sum(
+                    coeff_quad * ff_values_th[3] * L_values_th[0])
                 yokyquad_th = -yokxquad_th
-            
+
                 if y_radius > x_radius:
                     yokxdip_th, yokydip_th = yokydip_th, yokxdip_th
                     yokxquad_th, yokyquad_th = yokyquad_th, yokxquad_th
@@ -338,7 +352,7 @@ def yokoya_elliptic(x_radius, y_radius):
                 yokydip = np.interp(qr, qr_array, yokydip_array)
                 yokxquad = np.interp(qr, qr_array, yokxquad_array)
                 yokyquad = np.interp(qr, qr_array, yokyquad_array)
-    
+
     return (yoklong, yokxdip, yokydip, yokxquad, yokyquad)
 
 
diff --git a/tests/test_bunch.py b/tests/test_bunch.py
index 5d3a69b..25189d3 100644
--- a/tests/test_bunch.py
+++ b/tests/test_bunch.py
@@ -8,9 +8,9 @@ from mbtrack2 import Bunch
 def test_bunch_values(demo_ring):
     mp_number = 10
     current = 20e-3
-    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current, 
+    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current,
                     track_alive=True)
-    
+
     assert mybunch.mp_number == mp_number
     assert pytest.approx(mybunch.current) == current
     assert len(mybunch) == mp_number
@@ -19,35 +19,35 @@ def test_bunch_values(demo_ring):
     assert pytest.approx(mybunch.charge_per_mp) == current * demo_ring.T0 / mp_number
     assert pytest.approx(mybunch.particle_number) == current * demo_ring.T0 / e
     assert mybunch.is_empty == False
-    
+
 def test_bunch_magic(mybunch):
     for label in mybunch:
         np.testing.assert_allclose(mybunch[label], np.zeros(len(mybunch)))
         mybunch[label] = np.ones(len(mybunch))
         np.testing.assert_allclose(mybunch[label], np.ones(len(mybunch)))
-    
+
 def test_bunch_losses(mybunch):
     charge_init = mybunch.charge
     mybunch.alive[0] = False
     assert len(mybunch) == mybunch.mp_number - 1
     assert pytest.approx(mybunch.charge) == charge_init * len(mybunch) / mybunch.mp_number
-    
+
 def test_bunch_init_gauss(large_bunch):
     large_bunch.init_gaussian(mean=np.ones((6,)))
     np.testing.assert_allclose(large_bunch.mean, np.ones((6,)), rtol=1e-2)
-    
+
 def test_bunch_save_load(mybunch, demo_ring, tmp_path):
     mybunch["x"] += 1
     mybunch.save(str(tmp_path / "test"))
-    
+
     mybunch2 = Bunch(demo_ring, mp_number=1, current=1e-5)
     mybunch2.load(str(tmp_path / "test.hdf5"))
-    
+
     assert mybunch.mp_number == mybunch2.mp_number
     assert pytest.approx(mybunch.charge) == mybunch2.charge
     for label in mybunch:
         np.testing.assert_allclose(mybunch[label], mybunch2[label])
-        
+
 def test_bunch_stats(demo_ring, large_bunch):
     large_bunch.init_gaussian()
     np.testing.assert_array_almost_equal(large_bunch.mean, np.zeros((6,)), decimal=5)
@@ -55,7 +55,7 @@ def test_bunch_stats(demo_ring, large_bunch):
     np.testing.assert_allclose(large_bunch.std, sig, rtol=1e-2)
     np.testing.assert_allclose(large_bunch.emit[:2], demo_ring.emit, rtol=1e-2)
     np.testing.assert_allclose(large_bunch.cs_invariant[:2], demo_ring.emit*2, rtol=1e-2)
-    
+
 def test_bunch_binning(mybunch):
     mybunch.init_gaussian()
     (bins, sorted_index, profile, center) = mybunch.binning()
@@ -64,7 +64,7 @@ def test_bunch_binning(mybunch):
         assert bins[val] <= mybunch["tau"][i] <= bins[val+1]
         profile0[val] += 1
     np.testing.assert_allclose(profile0, profile)
-    
+
 def test_bunch_plots(mybunch):
     mybunch.init_gaussian()
     mybunch.plot_phasespace()
@@ -73,34 +73,33 @@ def test_bunch_plots(mybunch):
 def test_bunch_emittance(demo_ring):
     mp_number = 1_000_000
     current = 1.2e-3
-    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current, 
-                    track_alive=False)    
+    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current,
+                    track_alive=False)
     mybunch.init_gaussian()
-    np.testing.assert_allclose(mybunch.emit[0], demo_ring.emit[0], rtol=0, atol=1e-10,
+    np.testing.assert_allclose(mybunch.emit[0], demo_ring.emit[0], rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {demo_ring.emit[0]} initialised, {mybunch.emit[0]:} calculated')
-    np.testing.assert_allclose(mybunch.emit[1], demo_ring.emit[1], rtol=0, atol=1e-10,
+    np.testing.assert_allclose(mybunch.emit[1], demo_ring.emit[1], rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {demo_ring.emit[1]} initialised, {mybunch.emit[1]:} calculated')
 
-    np.testing.assert_allclose(mybunch.emit[0], mybunch.cs_invariant[0]/2, rtol=0, atol=1e-10,
+    np.testing.assert_allclose(mybunch.emit[0], mybunch.cs_invariant[0]/2, rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {mybunch.cs_invariant[0]/2} calculated with optics functions, {mybunch.emit[0]:} calculated with coordinates only')
-    np.testing.assert_allclose(mybunch.emit[1], mybunch.cs_invariant[1]/2, rtol=0, atol=1e-10,
+    np.testing.assert_allclose(mybunch.emit[1], mybunch.cs_invariant[1]/2, rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {mybunch.cs_invariant[1]/2} calculated with optics functions, {mybunch.emit[1]:} calculated with coordinates only')
-    
+
 
 def test_bunch_emittance_with_dispersion(demo_ring):
     mp_number = 1_000_000
     current = 1.2e-3
     demo_ring.optics.local_dispersion = np.array([1e-2, 1e-3, 1e-2, 1e-3])
-    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current, 
-                    track_alive=False)    
+    mybunch = Bunch(demo_ring, mp_number=mp_number, current=current,
+                    track_alive=False)
     mybunch.init_gaussian()
-    np.testing.assert_allclose(mybunch.emit[0], demo_ring.emit[0], rtol=0, atol=1e-9,
+    np.testing.assert_allclose(mybunch.emit[0], demo_ring.emit[0], rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {demo_ring.emit[0]} initialised, {mybunch.emit[0]:} calculated')
-    np.testing.assert_allclose(mybunch.emit[1], demo_ring.emit[1], rtol=0, atol=1e-9,
+    np.testing.assert_allclose(mybunch.emit[1], demo_ring.emit[1], rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {demo_ring.emit[1]} initialised, {mybunch.emit[1]:} calculated')
 
-    np.testing.assert_allclose(mybunch.emit[0], mybunch.cs_invariant[0]/2, rtol=0, atol=1e-9,
+    np.testing.assert_allclose(mybunch.emit[0], mybunch.cs_invariant[0]/2, rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {mybunch.cs_invariant[0]/2} calculated with optics functions, {mybunch.emit[0]:} calculated with coordinates only')
-    np.testing.assert_allclose(mybunch.emit[1], mybunch.cs_invariant[1]/2, rtol=0, atol=1e-9,
+    np.testing.assert_allclose(mybunch.emit[1], mybunch.cs_invariant[1]/2, rtol=1e-2, atol=0,
      err_msg=f'Emittances do not match. {mybunch.cs_invariant[1]/2} calculated with optics functions, {mybunch.emit[1]:} calculated with coordinates only')
-    
\ No newline at end of file
-- 
GitLab