diff --git a/mbtrack2/tracking/element.py b/mbtrack2/tracking/element.py
index 4385b57d2f44e87d82648d2ba97831c22051f581..752c044c71322585f056a3d4e467a3e23274715c 100644
--- a/mbtrack2/tracking/element.py
+++ b/mbtrack2/tracking/element.py
@@ -10,10 +10,48 @@ from copy import deepcopy
 from functools import wraps
 
 import numpy as np
+from scipy.special import factorial
 
 from mbtrack2.tracking.particles import Beam
 
 
+def _compute_chromatic_phase_advances(chro, bunch):
+    order = len(chro) // 2
+    if order == 1:
+        phase_advance_x = chro[0] * bunch["delta"]
+        phase_advance_y = chro[1] * bunch["delta"]
+    elif order == 2:
+        phase_advance_x = (chro[0] * bunch["delta"] +
+                           chro[2] / 2 * bunch["delta"]**2)
+        phase_advance_y = (chro[1] * bunch["delta"] +
+                           chro[3] / 2 * bunch["delta"]**2)
+    elif order == 3:
+        phase_advance_x = (chro[0] * bunch["delta"] +
+                           chro[2] / 2 * bunch["delta"]**2 +
+                           chro[4] / 6 * bunch["delta"]**3)
+        phase_advance_y = (chro[1] * bunch["delta"] +
+                           chro[3] / 2 * bunch["delta"]**2 +
+                           chro[5] / 6 * bunch["delta"]**3)
+    elif order == 4:
+        phase_advance_x = (chro[0] * bunch["delta"] +
+                           chro[2] / 2 * bunch["delta"]**2 +
+                           chro[4] / 6 * bunch["delta"]**3 +
+                           chro[6] / 24 * bunch["delta"]**4)
+        phase_advance_y = (chro[1] * bunch["delta"] +
+                           chro[3] / 2 * bunch["delta"]**2 +
+                           chro[5] / 6 * bunch["delta"]**3 +
+                           chro[7] / 24 * bunch["delta"]**4)
+    else:
+        coefs = np.array([1 / factorial(i + 1) for i in range(order + 1)])
+        coefs[0] = 0
+        chro = np.concatenate(([0, 0], chro))
+        phase_advance_x = np.polynomial.polynomial.Polynomial(
+            chro[::2] * coefs)(bunch['delta'])
+        phase_advance_y = np.polynomial.polynomial.Polynomial(
+            chro[1::2] * coefs)(bunch['delta'])
+    return phase_advance_x, phase_advance_y
+
+
 class Element(metaclass=ABCMeta):
     """
     Abstract Element class used for subclass inheritance to define all kinds
@@ -147,109 +185,6 @@ class SynchrotronRadiation(Element):
                                self.ring.T0 / self.ring.tau[1])**0.5 * rand
 
 
-class TransverseMap(Element):
-    """
-    Transverse map for a single turn in the synchrotron.
-
-    Parameters
-    ----------
-    ring : Synchrotron object
-    """
-
-    def __init__(self, ring):
-        self.ring = ring
-        self.alpha = self.ring.optics.local_alpha
-        self.beta = self.ring.optics.local_beta
-        self.gamma = self.ring.optics.local_gamma
-        self.dispersion = self.ring.optics.local_dispersion
-        if self.ring.adts is not None:
-            self.adts_poly = [
-                np.poly1d(self.ring.adts[0]),
-                np.poly1d(self.ring.adts[1]),
-                np.poly1d(self.ring.adts[2]),
-                np.poly1d(self.ring.adts[3]),
-            ]
-
-    @Element.parallel
-    def track(self, bunch):
-        """
-        Tracking method for the element.
-        No bunch to bunch interaction, so written for Bunch objects and
-        @Element.parallel is used to handle Beam objects.
-
-        Parameters
-        ----------
-        bunch : Bunch or Beam object
-        """
-
-        # Compute phase advance which depends on energy via chromaticity and ADTS
-        if self.ring.adts is None:
-            phase_advance_x = (
-                2 * np.pi *
-                (self.ring.tune[0] + self.ring.chro[0] * bunch["delta"]))
-            phase_advance_y = (
-                2 * np.pi *
-                (self.ring.tune[1] + self.ring.chro[1] * bunch["delta"]))
-        else:
-            Jx = ((self.ring.optics.local_gamma[0] * bunch["x"]**2) +
-                  (2 * self.ring.optics.local_alpha[0] * bunch["x"] *
-                   bunch["xp"]) +
-                  (self.ring.optics.local_beta[0] * bunch["xp"]**2))
-            Jy = ((self.ring.optics.local_gamma[1] * bunch["y"]**2) +
-                  (2 * self.ring.optics.local_alpha[1] * bunch["y"] *
-                   bunch["yp"]) +
-                  (self.ring.optics.local_beta[1] * bunch["yp"]**2))
-            phase_advance_x = (
-                2 * np.pi *
-                (self.ring.tune[0] + self.ring.chro[0] * bunch["delta"] +
-                 self.adts_poly[0](Jx) + self.adts_poly[2](Jy)))
-            phase_advance_y = (
-                2 * np.pi *
-                (self.ring.tune[1] + self.ring.chro[1] * bunch["delta"] +
-                 self.adts_poly[1](Jx) + self.adts_poly[3](Jy)))
-
-        # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
-        matrix = np.zeros((6, 6, len(bunch)), dtype=np.float64)
-
-        # Horizontal
-        c_x = np.cos(phase_advance_x)
-        s_x = np.sin(phase_advance_x)
-
-        matrix[0, 0, :] = c_x + self.alpha[0] * s_x
-        matrix[0, 1, :] = self.beta[0] * s_x
-        matrix[0, 2, :] = self.dispersion[0]
-        matrix[1, 0, :] = -1 * self.gamma[0] * s_x
-        matrix[1, 1, :] = c_x - self.alpha[0] * s_x
-        matrix[1, 2, :] = self.dispersion[1]
-        matrix[2, 2, :] = 1
-
-        # Vertical
-        c_y = np.cos(phase_advance_y)
-        s_y = np.sin(phase_advance_y)
-
-        matrix[3, 3, :] = c_y + self.alpha[1] * s_y
-        matrix[3, 4, :] = self.beta[1] * s_y
-        matrix[3, 5, :] = self.dispersion[2]
-        matrix[4, 3, :] = -1 * self.gamma[1] * s_y
-        matrix[4, 4, :] = c_y - self.alpha[1] * s_y
-        matrix[4, 5, :] = self.dispersion[3]
-        matrix[5, 5, :] = 1
-
-        x = (matrix[0, 0] * bunch["x"] + matrix[0, 1] * bunch["xp"] +
-             matrix[0, 2] * bunch["delta"])
-        xp = (matrix[1, 0] * bunch["x"] + matrix[1, 1] * bunch["xp"] +
-              matrix[1, 2] * bunch["delta"])
-        y = (matrix[3, 3] * bunch["y"] + matrix[3, 4] * bunch["yp"] +
-             matrix[3, 5] * bunch["delta"])
-        yp = (matrix[4, 3] * bunch["y"] + matrix[4, 4] * bunch["yp"] +
-              matrix[4, 5] * bunch["delta"])
-
-        bunch["x"] = x
-        bunch["xp"] = xp
-        bunch["y"] = y
-        bunch["yp"] = yp
-
-
 class SkewQuadrupole:
     """
     Thin skew quadrupole element used to introduce betatron coupling (the
@@ -346,6 +281,32 @@ class TransverseMapSector(Element):
         else:
             self.adts_poly = None
 
+    def _compute_new_coords(self, bunch, phase_advance, plane):
+        if plane == 'x':
+            i, j, coord, mom = 0, 0, 'x', 'xp'
+        elif plane == 'y':
+            i, j, coord, mom = 1, 2, 'y', 'yp'
+        else:
+            raise ValueError('plane should be either x or y')
+        c_u = np.cos(2 * np.pi * phase_advance)
+        s_u = np.sin(2 * np.pi * phase_advance)
+        M00 = np.sqrt(
+            self.beta1[i] / self.beta0[i]) * (c_u + self.alpha0[i] * s_u)
+        M01 = np.sqrt(self.beta0[i] * self.beta1[i]) * s_u
+        M02 = (self.dispersion1[j] - M00 * self.dispersion0[j] -
+               M01 * self.dispersion0[j + 1])
+        M10 = ((self.alpha0[i] - self.alpha1[i]) * c_u -
+               (1 + self.alpha0[i] * self.alpha1[i]) * s_u) / np.sqrt(
+                   self.beta0[i] * self.beta1[i])
+        M11 = np.sqrt(
+            self.beta0[i] / self.beta1[i]) * (c_u - self.alpha1[i] * s_u)
+        M12 = (self.dispersion1[j + 1] - M10 * self.dispersion0[j] -
+               M11 * self.dispersion0[j + 1])
+        M22 = 1
+        u = (M00 * bunch[coord] + M01 * bunch[mom] + M02 * bunch["delta"])
+        up = (M10 * bunch[coord] + M11 * bunch[mom] + M12 * bunch["delta"])
+        return u, up
+
     @Element.parallel
     def track(self, bunch):
         """
@@ -357,85 +318,45 @@ class TransverseMapSector(Element):
         ----------
         bunch : Bunch or Beam object
         """
-
+        phase_advance_x = self.tune_diff[0]
+        phase_advance_y = self.tune_diff[1]
         # Compute phase advance which depends on energy via chromaticity and ADTS
-        if self.adts_poly is None:
-            phase_advance_x = (
-                2 * np.pi *
-                (self.tune_diff[0] + self.chro_diff[0] * bunch["delta"]))
-            phase_advance_y = (
-                2 * np.pi *
-                (self.tune_diff[1] + self.chro_diff[1] * bunch["delta"]))
-        else:
+        if (np.array(self.chro_diff) != 0).any():
+            phase_advance_x_chro, phase_advance_y_chro = _compute_chromatic_phase_advances(
+                self.chro_diff, bunch)
+            phase_advance_x += phase_advance_x_chro
+            phase_advance_y += phase_advance_y_chro
+
+        if self.adts_poly is not None:
             Jx = ((self.gamma0[0] * bunch["x"]**2) +
                   (2 * self.alpha0[0] * bunch["x"] * self["xp"]) +
                   (self.beta0[0] * bunch["xp"]**2))
             Jy = ((self.gamma0[1] * bunch["y"]**2) +
                   (2 * self.alpha0[1] * bunch["y"] * bunch["yp"]) +
                   (self.beta0[1] * bunch["yp"]**2))
-            phase_advance_x = (
-                2 * np.pi *
-                (self.tune_diff[0] + self.chro_diff[0] * bunch["delta"] +
-                 self.adts_poly[0](Jx) + self.adts_poly[2](Jy)))
-            phase_advance_y = (
-                2 * np.pi *
-                (self.tune_diff[1] + self.chro_diff[1] * bunch["delta"] +
-                 self.adts_poly[1](Jx) + self.adts_poly[3](Jy)))
-
-        # 6x6 matrix corresponding to (x, xp, delta, y, yp, delta)
-        matrix = np.zeros((6, 6, len(bunch)))
-
-        # Horizontal
-        matrix[0, 0, :] = np.sqrt(self.beta1[0] / self.beta0[0]) * (
-            np.cos(phase_advance_x) + self.alpha0[0] * np.sin(phase_advance_x))
-        matrix[0, 1, :] = np.sqrt(
-            self.beta0[0] * self.beta1[0]) * np.sin(phase_advance_x)
-        matrix[0, 2, :] = (self.dispersion1[0] -
-                           matrix[0, 0, :] * self.dispersion0[0] -
-                           matrix[0, 1, :] * self.dispersion0[1])
-        matrix[1, 0, :] = (
-            (self.alpha0[0] - self.alpha1[0]) * np.cos(phase_advance_x) -
-            (1 + self.alpha0[0] * self.alpha1[0]) *
-            np.sin(phase_advance_x)) / np.sqrt(self.beta0[0] * self.beta1[0])
-        matrix[1, 1, :] = np.sqrt(self.beta0[0] / self.beta1[0]) * (
-            np.cos(phase_advance_x) - self.alpha1[0] * np.sin(phase_advance_x))
-        matrix[1, 2, :] = (self.dispersion1[1] -
-                           matrix[1, 0, :] * self.dispersion0[0] -
-                           matrix[1, 1, :] * self.dispersion0[1])
-        matrix[2, 2, :] = 1
-
-        # Vertical
-        matrix[3, 3, :] = np.sqrt(self.beta1[1] / self.beta0[1]) * (
-            np.cos(phase_advance_y) + self.alpha0[1] * np.sin(phase_advance_y))
-        matrix[3, 4, :] = np.sqrt(
-            self.beta0[1] * self.beta1[1]) * np.sin(phase_advance_y)
-        matrix[3, 5, :] = (self.dispersion1[2] -
-                           matrix[3, 3, :] * self.dispersion0[2] -
-                           matrix[3, 4, :] * self.dispersion0[3])
-        matrix[4, 3, :] = (
-            (self.alpha0[1] - self.alpha1[1]) * np.cos(phase_advance_y) -
-            (1 + self.alpha0[1] * self.alpha1[1]) *
-            np.sin(phase_advance_y)) / np.sqrt(self.beta0[1] * self.beta1[1])
-        matrix[4, 4, :] = np.sqrt(self.beta0[1] / self.beta1[1]) * (
-            np.cos(phase_advance_y) - self.alpha1[1] * np.sin(phase_advance_y))
-        matrix[4, 5, :] = (self.dispersion1[3] -
-                           matrix[4, 3, :] * self.dispersion0[2] -
-                           matrix[4, 4, :] * self.dispersion0[3])
-        matrix[5, 5, :] = 1
-
-        x = (matrix[0, 0, :] * bunch["x"] + matrix[0, 1, :] * bunch["xp"] +
-             matrix[0, 2, :] * bunch["delta"])
-        xp = (matrix[1, 0, :] * bunch["x"] + matrix[1, 1, :] * bunch["xp"] +
-              matrix[1, 2, :] * bunch["delta"])
-        y = (matrix[3, 3, :] * bunch["y"] + matrix[3, 4, :] * bunch["yp"] +
-             matrix[3, 5, :] * bunch["delta"])
-        yp = (matrix[4, 3, :] * bunch["y"] + matrix[4, 4, :] * bunch["yp"] +
-              matrix[4, 5, :] * bunch["delta"])
-
-        bunch["x"] = x
-        bunch["xp"] = xp
-        bunch["y"] = y
-        bunch["yp"] = yp
+            phase_advance_x += (self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
+            phase_advance_x += (self.adts_poly[0](Jx) + self.adts_poly[2](Jy))
+
+        bunch['x'], bunch['xp'] = self._compute_new_coords(
+            bunch, phase_advance_x, 'x')
+        bunch['y'], bunch['yp'] = self._compute_new_coords(
+            bunch, phase_advance_y, 'y')
+
+
+class TransverseMap(TransverseMapSector):
+    """
+    Transverse map for a single turn in the synchrotron.
+
+    Parameters
+    ----------
+    ring : Synchrotron object
+    """
+
+    def __init__(self, ring):
+        super().__init__(ring, ring.optics.local_alpha, ring.optics.local_beta,
+                         ring.optics.local_dispersion, ring.optics.local_alpha,
+                         ring.optics.local_beta, ring.optics.local_dispersion,
+                         2 * np.pi * ring.tune, ring.chro, ring.adts)
 
 
 def transverse_map_sector_generator(ring, positions):
@@ -470,17 +391,17 @@ def transverse_map_sector_generator(ring, positions):
     if ring.optics.use_local_values:
         for i in range(N_sec):
             sectors.append(
-                TransverseMapSector(
-                    ring,
-                    ring.optics.local_alpha,
-                    ring.optics.local_beta,
-                    ring.optics.local_dispersion,
-                    ring.optics.local_alpha,
-                    ring.optics.local_beta,
-                    ring.optics.local_dispersion,
-                    ring.tune / N_sec,
-                    ring.chro / N_sec,
-                ))
+                TransverseMapSector(ring,
+                                    ring.optics.local_alpha,
+                                    ring.optics.local_beta,
+                                    ring.optics.local_dispersion,
+                                    ring.optics.local_alpha,
+                                    ring.optics.local_beta,
+                                    ring.optics.local_dispersion,
+                                    2 * np.pi * ring.tune / N_sec,
+                                    ring.chro / N_sec,
+                                    adts=ring.adts /
+                                    N_sec if ring.adts else None))
     else:
         import at