diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py index 7fab92b56d9c37dabb6f3bdd2f110e32dfd490fe..821872928198f84b749cd4cd7e9a7578440d737a 100644 --- a/mbtrack2/tracking/element.py +++ b/mbtrack2/tracking/element.py @@ -8,7 +8,7 @@ included in the tracking. from abc import ABCMeta, abstractmethod from copy import deepcopy from functools import wraps - +import torch import numpy as np from mbtrack2.tracking.particles import Beam @@ -156,10 +156,10 @@ class TransverseMap(Element): """ def __init__(self, ring): self.ring = ring - self.alpha = self.ring.optics.local_alpha - self.beta = self.ring.optics.local_beta - self.gamma = self.ring.optics.local_gamma - self.dispersion = self.ring.optics.local_dispersion + self.alpha = torch.tensor(self.ring.optics.local_alpha) + self.beta = torch.tensor(self.ring.optics.local_beta) + self.gamma = torch.tensor(self.ring.optics.local_gamma) + self.dispersion = torch.tensor(self.ring.optics.local_dispersion) if self.ring.adts is not None: self.adts_poly = [ np.poly1d(self.ring.adts[0]), @@ -201,11 +201,10 @@ class TransverseMap(Element): self.adts_poly[1](Jx) + self.adts_poly[3](Jy)) # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta) - matrix = np.zeros((6, 6, len(bunch)), dtype=np.float64) - + matrix = torch.zeros((6, 6, len(bunch)), dtype=torch.float64) # Horizontal - c_x = np.cos(phase_advance_x) - s_x = np.sin(phase_advance_x) + c_x = torch.cos(phase_advance_x) + s_x = torch.sin(phase_advance_x) matrix[0, 0, :] = c_x + self.alpha[0] * s_x matrix[0, 1, :] = self.beta[0] * s_x @@ -216,8 +215,8 @@ class TransverseMap(Element): matrix[2, 2, :] = 1 # Vertical - c_y = np.cos(phase_advance_y) - s_y = np.sin(phase_advance_y) + c_y = torch.cos(phase_advance_y) + s_y = torch.sin(phase_advance_y) matrix[3, 3, :] = c_y + self.alpha[1] * s_y matrix[3, 4, :] = self.beta[1] * s_y @@ -226,7 +225,7 @@ class TransverseMap(Element): matrix[4, 4, :] = c_y - self.alpha[1] * s_y matrix[4, 5, :] = self.dispersion[3] matrix[5, 5, :] = 1 - + x = matrix[0, 0, :] * bunch["x"] + matrix[ 0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"] xp = matrix[1, 0, :] * bunch["x"] + matrix[ diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py index f334be881a7fd0a559eadb6826fb30e9b8835b4b..9e28a7341a95ac83241a287c8e71fe6e55d3b257 100644 --- a/mbtrack2/tracking/particles.py +++ b/mbtrack2/tracking/particles.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns +import torch from scipy.constants import c, e, m_e, m_p @@ -129,18 +130,18 @@ class Bunch: self._mp_number = int(mp_number) self.particles = { - "x": np.zeros(self.mp_number, dtype=np.float64), - "xp": np.zeros(self.mp_number, dtype=np.float64), - "y": np.zeros(self.mp_number, dtype=np.float64), - "yp": np.zeros(self.mp_number, dtype=np.float64), - "tau": np.zeros(self.mp_number, dtype=np.float64), - "delta": np.zeros(self.mp_number, dtype=np.float64), + "x": torch.zeros(self.mp_number, dtype=torch.float64), + "xp": torch.zeros(self.mp_number, dtype=torch.float64), + "y": torch.zeros(self.mp_number, dtype=torch.float64), + "yp": torch.zeros(self.mp_number, dtype=torch.float64), + "tau": torch.zeros(self.mp_number, dtype=torch.float64), + "delta": torch.zeros(self.mp_number, dtype=torch.float64), } self.track_alive = track_alive - self.alive = np.ones((self.mp_number, ), dtype=bool) + self.alive = torch.ones(self.mp_number, dtype=torch.bool) self.current = current if not alive: - self.alive = np.zeros((self.mp_number, ), dtype=bool) + self.alive = torch.zeros(self.mp_number, dtype=torch.bool) if load_from_file is not None: self.load(load_from_file, load_suffix, track_alive) @@ -296,14 +297,14 @@ class Bunch: """ if mean is None: - mean = np.zeros((6, )) + mean = torch.zeros(6) if cov is None: sigma_0 = kwargs.get("sigma_0", self.ring.sigma_0) sigma_delta = kwargs.get("sigma_delta", self.ring.sigma_delta) optics = kwargs.get("optics", self.ring.optics) - cov = np.zeros((6, 6)) + cov = torch.zeros(6, 6) cov[0, 0] = self.ring.emit[0] * optics.local_beta[0] + ( optics.local_dispersion[0] * self.ring.sigma_delta)**2 cov[1, 1] = self.ring.emit[0] * optics.local_gamma[0] + ( @@ -335,7 +336,9 @@ class Bunch: cov[4, 4] = sigma_0**2 cov[5, 5] = sigma_delta**2 - values = np.random.multivariate_normal(mean, cov, size=self.mp_number) + # values = np.random.multivariate_normal(mean, cov, size=self.mp_number) + dist = torch.distributions.MultivariateNormal(mean, cov) + values = dist.sample((self.mp_number,)) self.particles["x"] = values[:, 0] self.particles["xp"] = values[:, 1] self.particles["y"] = values[:, 2] @@ -366,16 +369,17 @@ class Bunch: Center of each bin. """ + self[dimension] = torch.Tensor.contiguous(self[dimension]) bin_min = self[dimension].min() bin_min = min(bin_min * 0.99, bin_min * 1.01) bin_max = self[dimension].max() bin_max = max(bin_max * 0.99, bin_max * 1.01) - bins = np.linspace(bin_min, bin_max, n_bin) + bins = torch.linspace(bin_min, bin_max, n_bin) center = (bins[1:] + bins[:-1]) / 2 - sorted_index = np.searchsorted(bins, self[dimension], side='left') + sorted_index = torch.searchsorted(bins, self[dimension]) sorted_index -= 1 - profile = np.bincount(sorted_index, minlength=n_bin - 1) + profile = torch.bincount(sorted_index, minlength=n_bin-1) return (bins, sorted_index, profile, center) diff --git a/mbtrack2/tracking/synchrotron.py b/mbtrack2/tracking/synchrotron.py index 24fe189a8833c0c7ebde7ed24f77989c7ecfef88..f287eec9ad1a55257c0181e90b644d32d3460acc 100644 --- a/mbtrack2/tracking/synchrotron.py +++ b/mbtrack2/tracking/synchrotron.py @@ -6,6 +6,7 @@ Module where the Synchrotron class is defined. import matplotlib.pyplot as plt import numpy as np from scipy.constants import c, e +import torch class Synchrotron: @@ -293,7 +294,7 @@ class Synchrotron: Momentum compaction. """ - return self.mcf(delta) - 1 / (self.gamma**2) + return torch.tensor(self.mcf(delta)) - 1 / (self.gamma**2) def sigma(self, position=None): """ diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py index 46abfbd835b7bbd8e9fee32516c0e22010aea4f1..8fad8b4bba694603c9bf2c227d74fe55f4b9b56b 100644 --- a/mbtrack2/tracking/wakepotential.py +++ b/mbtrack2/tracking/wakepotential.py @@ -7,10 +7,9 @@ deal with the single bunch and multi-bunch wakes. import matplotlib.pyplot as plt import numpy as np import pandas as pd +import torch from scipy import signal from scipy.constants import c, mu_0, pi -from scipy.interpolate import interp1d - from mbtrack2.tracking.element import Element from mbtrack2.utilities.spectrum import gaussian_bunch @@ -148,11 +147,11 @@ class WakePotential(Element): Dipole moment of the bunch. """ - dipole = np.empty((self.n_bin - 1, )) + dipole = torch.zeros(self.n_bin - 1) for i in range(self.n_bin - 1): dipole[i] = bunch[plane][self.sorted_index == i].sum() dipole = dipole / self.profile - dipole[np.isnan(dipole)] = 0 + dipole[torch.isnan(dipole)] = 0 # Add N values to get same size as tau/profile if self.n_bin % 2 == 0: @@ -308,7 +307,7 @@ class WakePotential(Element): else: Wp_interp = np.interp(self.center, tau0 + self.tau_mean, Wp, 0, 0) - Wp_interp= Wp_interp[self.sorted_index] + Wp_interp= torch.tensor(Wp_interp[self.sorted_index]) if wake_type == "Wlong": bunch["delta"] += Wp_interp * bunch.charge / self.ring.E0 elif wake_type == "Wxdip":