diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py index 299a354eb252fbb25b07ff455f4647528a01a938..0f69f6cfb391776d5f747b2230fb9a9c374df895 100644 --- a/mbtrack2/tracking/element.py +++ b/mbtrack2/tracking/element.py @@ -444,6 +444,10 @@ def transverse_map_sector_generator(ring, positions): """ N_sec = len(positions) sectors = [] + if hasattr(ring, "adts") and ring.adts is not None: + adts = np.array([val / N_sec for val in ring.adts]) + else: + adts = None if ring.optics.use_local_values: for i in range(N_sec): sectors.append( @@ -455,56 +459,54 @@ def transverse_map_sector_generator(ring, positions): ring.optics.local_beta, ring.optics.local_dispersion, 2 * np.pi * ring.tune / N_sec, - ring.chro / N_sec, - adts=ring.adts / - N_sec if ring.adts else None)) + np.asarray(ring.chro) / N_sec, + adts=adts)) else: import at - def _compute_chro(ring, pos, dp=1e-2, order=4): + def _compute_chro(ring, N_sec, dp=1e-2, order=4): lat = deepcopy(ring.optics.lattice) lat.append(at.Marker("END")) - fit, dpa, tune = at.physics.nonlinear.chromaticity(lat, - method='linopt', - dpm=dp, - n_points=100, - order=order) + fit, _, _ = at.physics.nonlinear.chromaticity(lat, + method='linopt', + dpm=dp, + n_points=100, + order=order) Chrox, Chroy = fit[0, 1:], fit[1, 1:] - chrox = np.interp(pos, s, Chrox) - chroy = np.interp(pos, s, Chroy) - + chrox = np.array([np.linspace(0, val, N_sec) for val in Chrox]) + chroy = np.array([np.linspace(0, val, N_sec) for val in Chroy]) return np.array([chrox, chroy]) + _chro = _compute_chro(ring, N_sec) for i in range(N_sec): alpha0 = ring.optics.alpha(positions[i]) beta0 = ring.optics.beta(positions[i]) dispersion0 = ring.optics.dispersion(positions[i]) mu0 = ring.optics.mu(positions[i]) - chro0 = _compute_chro(ring, positions[i]) + chro0 = _chro[:, i] if i != (N_sec - 1): alpha1 = ring.optics.alpha(positions[i + 1]) beta1 = ring.optics.beta(positions[i + 1]) dispersion1 = ring.optics.dispersion(positions[i + 1]) mu1 = ring.optics.mu(positions[i + 1]) - chro1 = _compute_chro(ring, positions[i + 1]) + chro1 = _chro[:, i + 1] else: alpha1 = ring.optics.alpha(positions[0]) beta1 = ring.optics.beta(positions[0]) dispersion1 = ring.optics.dispersion(positions[0]) mu1 = ring.optics.mu(ring.L) - chro1 = _compute_chro(ring, ring.L) + chro1 = _chro[:, -1] phase_diff = mu1 - mu0 chro_diff = chro1 - chro0 sectors.append( - TransverseMapSector( - ring, - alpha0, - beta0, - dispersion0, - alpha1, - beta1, - dispersion1, - phase_diff, - chro_diff, - )) + TransverseMapSector(ring, + alpha0, + beta0, + dispersion0, + alpha1, + beta1, + dispersion1, + phase_diff, + chro_diff, + adts=adts)) return sectors diff --git a/mbtrack2/tracking/parallel.py b/mbtrack2/tracking/parallel.py index a441f2d435fa7dff87e9e7c1116036903299cf32..c921f610feccbd1019f672f77a669bc3fcc0b4dc 100644 --- a/mbtrack2/tracking/parallel.py +++ b/mbtrack2/tracking/parallel.py @@ -81,7 +81,7 @@ class Mpi: Filling pattern of the beam, like Beam.filling_pattern """ if (filling_pattern.sum() != self.size): - raise ValueError("The number of processors must be equal to the" + raise ValueError("The number of processors must be equal to the " "number of (non-empty) bunches.") table = np.zeros((self.size, 2), dtype=int) table[:, 0] = np.arange(0, self.size) diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py index 0deea66f464cf496eb7fd75e0822bf5d64c95b05..cf872debb24e5241098a9299f899dd8a68219f6d 100644 --- a/mbtrack2/tracking/particles.py +++ b/mbtrack2/tracking/particles.py @@ -687,6 +687,8 @@ class Beam: Initialize beam with a given filling pattern and marco-particle number per bunch. Then initialize the different bunches with a 6D gaussian phase space. + init_bunch_list_mpi(bunch, filling_pattern) + Initialize a beam using MPI parallelisation with a Bunch per core. mpi_init() Switch on MPI parallelisation and initialise a Mpi object mpi_gather() @@ -714,6 +716,8 @@ class Beam: raise ValueError(("The length of the bunch list is {} ".format( len(bunch_list)) + "but should be {}".format(self.ring.h))) self.bunch_list = bunch_list + self.update_filling_pattern() + self.update_distance_between_bunches() def __len__(self): """Return the number of (not empty) bunches""" @@ -864,6 +868,40 @@ class Beam: for bunch in self.not_empty: bunch.init_gaussian() + def init_bunch_list_mpi(self, bunch, filling_pattern): + """ + Initialize a beam using MPI parallelisation with a Bunch per core. + + Parameters + ---------- + bunch : Bunch object + The bunch given should probably depend on the mpi.rank so that each + core can track a different bunch. + Example: beam.init_bunch_list_mpi(bunch_list[comm.rank], filling_pattern) + filling_pattern : array-like of bool of length ring.h + Filling pattern of the beam as a list or an array of bool. + + """ + filling_pattern = np.array(filling_pattern) + + if len(filling_pattern) != self.ring.h: + raise ValueError(("The length of filling pattern is {} ".format( + len(filling_pattern)) + + "but should be {}".format(self.ring.h))) + + if filling_pattern.dtype != np.dtype("bool"): + raise TypeError("dtype {} should be bool.".format( + filling_pattern.dtype)) + + self.bunch_list = [ + Bunch(self.ring, mp_number=1, alive=filling_pattern[i]) + for i in range(self.ring.h) + ] + self.update_filling_pattern() + self.update_distance_between_bunches() + self.mpi_init() + self[self.mpi.bunch_num] = bunch + def update_filling_pattern(self): """Update the beam filling pattern.""" filling_pattern = [] diff --git a/tests/unit/tracking/test_element.py b/tests/unit/tracking/test_element.py index 6fcbe8d15a38406e45adebbacfed1e19725892d1..6084010a573b6e3abd14ad32fba28d96c4f042eb 100644 --- a/tests/unit/tracking/test_element.py +++ b/tests/unit/tracking/test_element.py @@ -374,3 +374,54 @@ class TestTransverseMap: assert not np.array_equal(initial_values[attr], current[attr]) assert not np.array_equal(initial_values[attr], old[attr]) np.testing.assert_allclose(current[attr], old[attr]) + +class TestTransverseMapSectorGenerator: + + # Generate sectors with local optics values when use_local_values is True + def test_local_optics_values(self, demo_ring): + positions = np.array([0, 25, 50, 75]) + sectors = transverse_map_sector_generator(demo_ring, positions) + + assert len(sectors) == len(positions) + for sector in sectors: + assert np.array_equal(sector.alpha0, demo_ring.optics.local_alpha) + assert np.array_equal(sector.beta0, demo_ring.optics.local_beta) + assert np.array_equal(sector.dispersion0, demo_ring.optics.local_dispersion) + + # Generate sectors with AT lattice optics when use_local_values is False + def test_at_lattice_optics(self, ring_with_at_lattice): + positions = np.array([0, 25, 50, 75]) + sectors = transverse_map_sector_generator(ring_with_at_lattice, positions) + + assert len(sectors) == len(positions) + for i, sector in enumerate(sectors): + assert np.array_equal(sector.alpha0, ring_with_at_lattice.optics.alpha(positions[i])) + assert np.array_equal(sector.beta0, ring_with_at_lattice.optics.beta(positions[i])) + + # Calculate phase differences between consecutive positions + def test_phase_differences(self, ring_with_at_lattice): + positions = np.array([0, 25, 50]) + ring_with_at_lattice.optics.use_local_values = False + sectors = transverse_map_sector_generator(ring_with_at_lattice, positions) + + for i, sector in enumerate(sectors): + if i < len(positions)-1: + expected_phase = ring_with_at_lattice.optics.mu(positions[i+1]) - ring_with_at_lattice.optics.mu(positions[i]) + assert np.allclose(sector.tune_diff * (2 * np.pi), expected_phase) + + # Compute chromaticity differences between sectors + def test_chromaticity_differences(self, demo_ring): + positions = np.array([0, 50]) + sectors = transverse_map_sector_generator(demo_ring, positions) + + expected_chro = np.asarray(demo_ring.chro) / len(positions) + assert np.allclose(sectors[0].chro_diff, expected_chro) + + # Handle ADTS parameters correctly when provided + def test_adts_handling(self, demo_ring): + demo_ring.adts = np.array([1.0, 1.0, 1.0, 1.0]) + positions = np.array([0, 50]) + sectors = transverse_map_sector_generator(demo_ring, positions) + + expected_adts = demo_ring.adts / len(positions) + assert np.allclose(sectors[0].adts_poly[0].coef, expected_adts[0]) \ No newline at end of file diff --git a/tests/unit/tracking/test_particle.py b/tests/unit/tracking/test_particle.py index 0d527f752eace2f74ff9a584f789a745e7615ff4..2f0a174e8bed0ef2031270c3194cb38ced935567 100644 --- a/tests/unit/tracking/test_particle.py +++ b/tests/unit/tracking/test_particle.py @@ -304,4 +304,25 @@ class TestBeam: assert mock_mpi_instance.comm.allgather.called beam.mpi_close() assert not beam.mpi_switch - assert beam.mpi is None \ No newline at end of file + assert beam.mpi is None + + def test_init_bunch_list(self, demo_ring): + filling_pattern = np.ones((demo_ring.h,), dtype=bool) + filling_pattern[5] = False + filling_pattern[8:10] = False + bunch_list = [Bunch(demo_ring, mp_number=1, alive=filling_pattern[i]) for i in range(demo_ring.h)] + beam = Beam(demo_ring, bunch_list) + assert len(beam) == filling_pattern.sum() + np.testing.assert_array_equal(beam.filling_pattern, filling_pattern) + assert beam.distance_between_bunches is not None + + def test_init_bunch_list_mpi(self, demo_ring, generate_bunch): + filling_pattern = np.zeros((demo_ring.h,), dtype=bool) + filling_pattern[0] = True + beam = Beam(demo_ring) + beam.init_bunch_list_mpi(generate_bunch(), filling_pattern) + + assert len(beam) == 1 + np.testing.assert_array_equal(beam.filling_pattern, filling_pattern) + assert beam.distance_between_bunches[0] == demo_ring.h + \ No newline at end of file