From 7864c7e866b780fe2c8f7acaacf5881da461bf51 Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 11 Dec 2024 12:06:56 +0100
Subject: [PATCH] [Fix] CavityResonator.init_phasor_track and
 CavityResonator.init_phasor

---
 mbtrack2/tracking/rf.py        |  9 +++++----
 tests/unit/tracking/test_rf.py | 13 +++++++++++++
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/mbtrack2/tracking/rf.py b/mbtrack2/tracking/rf.py
index 3abc5dc..1bfdce7 100644
--- a/mbtrack2/tracking/rf.py
+++ b/mbtrack2/tracking/rf.py
@@ -387,10 +387,11 @@ class CavityResonator():
 
                 if beam.mpi_switch:
                     # get shared bunch profile for current bunch
+                    beam.mpi.share_distributions(beam, n_bin=self.n_bin)
                     center = beam.mpi.tau_center[j]
                     profile = beam.mpi.tau_profile[j]
-                    bin_length = beam.mpi.tau_bin_length[j]
-                    charge_per_mp = beam.mpi.charge_per_mp_all[j]
+                    bin_length = float(beam.mpi.tau_bin_length[j][0])
+                    charge_per_mp = float(beam.mpi.charge_per_mp_all[j])
                 else:
                     if i == 0:
                         # get bunch profile for current bunch
@@ -529,8 +530,8 @@ class CavityResonator():
                 beam.mpi.share_distributions(beam, n_bin=self.n_bin)
                 center[:, index] = beam.mpi.tau_center[j]
                 profile[:, index] = beam.mpi.tau_profile[j]
-                bin_length[index] = beam.mpi.bin_length[j]
-                charge_per_mp[index] = beam.mpi.charge_per_mp_all[j]
+                bin_length[index] = float(beam.mpi.tau_bin_length[j][0])
+                charge_per_mp[index] = float(beam.mpi.charge_per_mp_all[j])
             else:
                 (bins[:, index], sorted_index, profile[:, index],
                  center[:, index]) = bunch.binning(n_bin=self.n_bin)
diff --git a/tests/unit/tracking/test_rf.py b/tests/unit/tracking/test_rf.py
index 1f89e38..6c0cb0e 100644
--- a/tests/unit/tracking/test_rf.py
+++ b/tests/unit/tracking/test_rf.py
@@ -105,6 +105,19 @@ class TestCavityResonator:
         assert phasor_init_phasor_track != init_phasor
         assert np.allclose(phasor_init_phasor, phasor_init_phasor_track, rtol=1e-2)
         
+    def test_phasor_init_mpi(self, cav_res, beam_1bunch_mpi):
+        init_phasor = cav_res.beam_phasor
+        cav_res.init_phasor_track(beam_1bunch_mpi)
+        phasor_init_phasor_track = cav_res.beam_phasor
+        
+        cav_res.beam_phasor = init_phasor
+        cav_res.init_phasor(beam_1bunch_mpi)
+        phasor_init_phasor = cav_res.beam_phasor
+        
+        assert phasor_init_phasor != init_phasor
+        assert phasor_init_phasor_track != init_phasor
+        assert np.allclose(phasor_init_phasor, phasor_init_phasor_track, rtol=1e-2)
+        
     # Setting detune updates _detune, _fr, _wr, and _psi correctly
     def test_detune(self, cav_res):
         detune_value = 1000
-- 
GitLab