diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index 7fab92b56d9c37dabb6f3bdd2f110e32dfd490fe..821872928198f84b749cd4cd7e9a7578440d737a 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -8,7 +8,7 @@ included in the tracking.
 from abc import ABCMeta, abstractmethod
 from copy import deepcopy
 from functools import wraps
-
+import torch
 import numpy as np
 
 from mbtrack2.tracking.particles import Beam
@@ -156,10 +156,10 @@ class TransverseMap(Element):
     """
     def __init__(self, ring):
         self.ring = ring
-        self.alpha = self.ring.optics.local_alpha
-        self.beta = self.ring.optics.local_beta
-        self.gamma = self.ring.optics.local_gamma
-        self.dispersion = self.ring.optics.local_dispersion
+        self.alpha = torch.tensor(self.ring.optics.local_alpha)
+        self.beta = torch.tensor(self.ring.optics.local_beta)
+        self.gamma = torch.tensor(self.ring.optics.local_gamma)
+        self.dispersion = torch.tensor(self.ring.optics.local_dispersion)
         if self.ring.adts is not None:
             self.adts_poly = [
                 np.poly1d(self.ring.adts[0]),
@@ -201,11 +201,10 @@ class TransverseMap(Element):
                 self.adts_poly[1](Jx) + self.adts_poly[3](Jy))
 
         # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
-        matrix = np.zeros((6, 6, len(bunch)), dtype=np.float64)
-
+        matrix = torch.zeros((6, 6, len(bunch)), dtype=torch.float64)
         # Horizontal
-        c_x = np.cos(phase_advance_x)
-        s_x = np.sin(phase_advance_x)
+        c_x = torch.cos(phase_advance_x)
+        s_x = torch.sin(phase_advance_x)
 
         matrix[0, 0, :] = c_x + self.alpha[0] * s_x
         matrix[0, 1, :] = self.beta[0] * s_x
@@ -216,8 +215,8 @@ class TransverseMap(Element):
         matrix[2, 2, :] = 1
 
         # Vertical
-        c_y = np.cos(phase_advance_y)
-        s_y = np.sin(phase_advance_y)
+        c_y = torch.cos(phase_advance_y)
+        s_y = torch.sin(phase_advance_y)
 
         matrix[3, 3, :] = c_y + self.alpha[1] * s_y
         matrix[3, 4, :] = self.beta[1] * s_y
@@ -226,7 +225,7 @@ class TransverseMap(Element):
         matrix[4, 4, :] = c_y - self.alpha[1] * s_y
         matrix[4, 5, :] = self.dispersion[3]
         matrix[5, 5, :] = 1
-
+        
         x = matrix[0, 0, :] * bunch["x"] + matrix[
             0, 1, :] * bunch["xp"] + matrix[0, 2, :] * bunch["delta"]
         xp = matrix[1, 0, :] * bunch["x"] + matrix[
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index f334be881a7fd0a559eadb6826fb30e9b8835b4b..9e28a7341a95ac83241a287c8e71fe6e55d3b257 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
+import torch
 from scipy.constants import c, e, m_e, m_p
 
 
@@ -129,18 +130,18 @@ class Bunch:
         self._mp_number = int(mp_number)
 
         self.particles = {
-            "x": np.zeros(self.mp_number, dtype=np.float64),
-            "xp": np.zeros(self.mp_number, dtype=np.float64),
-            "y": np.zeros(self.mp_number, dtype=np.float64),
-            "yp": np.zeros(self.mp_number, dtype=np.float64),
-            "tau": np.zeros(self.mp_number, dtype=np.float64),
-            "delta": np.zeros(self.mp_number, dtype=np.float64),
+            "x": torch.zeros(self.mp_number, dtype=torch.float64),
+            "xp": torch.zeros(self.mp_number, dtype=torch.float64),
+            "y": torch.zeros(self.mp_number, dtype=torch.float64),
+            "yp": torch.zeros(self.mp_number, dtype=torch.float64),
+            "tau": torch.zeros(self.mp_number, dtype=torch.float64),
+            "delta": torch.zeros(self.mp_number, dtype=torch.float64),
         }
         self.track_alive = track_alive
-        self.alive = np.ones((self.mp_number, ), dtype=bool)
+        self.alive = torch.ones(self.mp_number, dtype=torch.bool)
         self.current = current
         if not alive:
-            self.alive = np.zeros((self.mp_number, ), dtype=bool)
+            self.alive = torch.zeros(self.mp_number, dtype=torch.bool)
 
         if load_from_file is not None:
             self.load(load_from_file, load_suffix, track_alive)
@@ -296,14 +297,14 @@ class Bunch:
 
         """
         if mean is None:
-            mean = np.zeros((6, ))
+            mean = torch.zeros(6)
 
         if cov is None:
             sigma_0 = kwargs.get("sigma_0", self.ring.sigma_0)
             sigma_delta = kwargs.get("sigma_delta", self.ring.sigma_delta)
             optics = kwargs.get("optics", self.ring.optics)
 
-            cov = np.zeros((6, 6))
+            cov = torch.zeros(6, 6)
             cov[0, 0] = self.ring.emit[0] * optics.local_beta[0] + (
                 optics.local_dispersion[0] * self.ring.sigma_delta)**2
             cov[1, 1] = self.ring.emit[0] * optics.local_gamma[0] + (
@@ -335,7 +336,9 @@ class Bunch:
             cov[4, 4] = sigma_0**2
             cov[5, 5] = sigma_delta**2
 
-        values = np.random.multivariate_normal(mean, cov, size=self.mp_number)
+        # values = np.random.multivariate_normal(mean, cov, size=self.mp_number)
+        dist = torch.distributions.MultivariateNormal(mean, cov)
+        values = dist.sample((self.mp_number,))
         self.particles["x"] = values[:, 0]
         self.particles["xp"] = values[:, 1]
         self.particles["y"] = values[:, 2]
@@ -366,16 +369,17 @@ class Bunch:
             Center of each bin.
 
         """
+        self[dimension] = torch.Tensor.contiguous(self[dimension])
         bin_min = self[dimension].min()
         bin_min = min(bin_min * 0.99, bin_min * 1.01)
         bin_max = self[dimension].max()
         bin_max = max(bin_max * 0.99, bin_max * 1.01)
 
-        bins = np.linspace(bin_min, bin_max, n_bin)
+        bins = torch.linspace(bin_min, bin_max, n_bin)
         center = (bins[1:] + bins[:-1]) / 2
-        sorted_index = np.searchsorted(bins, self[dimension], side='left')
+        sorted_index = torch.searchsorted(bins, self[dimension])
         sorted_index -= 1
-        profile = np.bincount(sorted_index, minlength=n_bin - 1)
+        profile = torch.bincount(sorted_index, minlength=n_bin-1)
 
         return (bins, sorted_index, profile, center)
 
diff --git a/mbtrack2/tracking/synchrotron.py b/mbtrack2/tracking/synchrotron.py
index 24fe189a8833c0c7ebde7ed24f77989c7ecfef88..f287eec9ad1a55257c0181e90b644d32d3460acc 100644
--- a/mbtrack2/tracking/synchrotron.py
+++ b/mbtrack2/tracking/synchrotron.py
@@ -6,6 +6,7 @@ Module where the Synchrotron class is defined.
 import matplotlib.pyplot as plt
 import numpy as np
 from scipy.constants import c, e
+import torch
 
 
 class Synchrotron:
@@ -293,7 +294,7 @@ class Synchrotron:
             Momentum compaction.
 
         """
-        return self.mcf(delta) - 1 / (self.gamma**2)
+        return torch.tensor(self.mcf(delta)) - 1 / (self.gamma**2)
 
     def sigma(self, position=None):
         """
diff --git a/mbtrack2/tracking/wakepotential.py b/mbtrack2/tracking/wakepotential.py
index 46abfbd835b7bbd8e9fee32516c0e22010aea4f1..8fad8b4bba694603c9bf2c227d74fe55f4b9b56b 100644
--- a/mbtrack2/tracking/wakepotential.py
+++ b/mbtrack2/tracking/wakepotential.py
@@ -7,10 +7,9 @@ deal with the single bunch and multi-bunch wakes.
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import torch
 from scipy import signal
 from scipy.constants import c, mu_0, pi
-from scipy.interpolate import interp1d
-
 from mbtrack2.tracking.element import Element
 from mbtrack2.utilities.spectrum import gaussian_bunch
 
@@ -148,11 +147,11 @@ class WakePotential(Element):
             Dipole moment of the bunch.
 
         """
-        dipole = np.empty((self.n_bin - 1, ))
+        dipole = torch.zeros(self.n_bin - 1)
         for i in range(self.n_bin - 1):
             dipole[i] = bunch[plane][self.sorted_index == i].sum()
         dipole = dipole / self.profile
-        dipole[np.isnan(dipole)] = 0
+        dipole[torch.isnan(dipole)] = 0
 
         # Add N values to get same size as tau/profile
         if self.n_bin % 2 == 0:
@@ -308,7 +307,7 @@ class WakePotential(Element):
                 else:
                     Wp_interp = np.interp(self.center, tau0 + self.tau_mean, Wp,
                                           0, 0)
-                    Wp_interp= Wp_interp[self.sorted_index]
+                    Wp_interp= torch.tensor(Wp_interp[self.sorted_index])
                 if wake_type == "Wlong":
                     bunch["delta"] += Wp_interp * bunch.charge / self.ring.E0
                 elif wake_type == "Wxdip":