From fe16070b23b5a6e250c6fbc14e7fbafcb3b2b7a4 Mon Sep 17 00:00:00 2001
From: Alexis Gamelin <alexis.gamelin@synchrotron-soleil.fr>
Date: Fri, 5 Apr 2024 10:53:13 +0200
Subject: [PATCH] Improve Sweep element

Sweep can now be applied only to a single bunch of a beam and on different planes.
Add option to save higher order spectrums in  BunchSpectrumMonitor.
Modify plot_bunchspectrum and streak_bunchspectrum to plot higher order spectrums.
Add docstrings for new features.
---
 mbtrack2/tracking/__init__.py          |   2 +-
 mbtrack2/tracking/excite.py            | 115 +++++++++++++++++++++++++
 mbtrack2/tracking/monitors/monitors.py |  71 ++++++++-------
 mbtrack2/tracking/monitors/plotting.py |  18 +---
 mbtrack2/tracking/particles.py         |  12 ++-
 mbtrack2/tracking/sweep.py             |  42 ---------
 6 files changed, 171 insertions(+), 89 deletions(-)
 create mode 100644 mbtrack2/tracking/excite.py
 delete mode 100644 mbtrack2/tracking/sweep.py

diff --git a/mbtrack2/tracking/__init__.py b/mbtrack2/tracking/__init__.py
index a4ea3ff..62e879f 100644
--- a/mbtrack2/tracking/__init__.py
+++ b/mbtrack2/tracking/__init__.py
@@ -31,4 +31,4 @@ from mbtrack2.tracking.wakepotential import (
     LongRangeResistiveWall,
     WakePotential,
 )
-from mbtrack2.tracking.sweep import Sweep
\ No newline at end of file
+from mbtrack2.tracking.excite import Sweep
\ No newline at end of file
diff --git a/mbtrack2/tracking/excite.py b/mbtrack2/tracking/excite.py
new file mode 100644
index 0000000..61075d5
--- /dev/null
+++ b/mbtrack2/tracking/excite.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+"""
+Module to deal with different kinds of beam excitation.
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+from mbtrack2.tracking.element import Element
+from mbtrack2.tracking.particles import Bunch, Beam
+from scipy.signal import chirp
+
+class Sweep(Element):
+    """
+    Element which excite the beam in between two frequencies, i.e. apply 
+    frequency sweep (chirp) on all or a given bunch in the chosen plane.
+   
+    If applied to a full beam, the excitation is the same (and at the same time)
+    for all bunches, so it drives a growth of coupled bunch mode 0.
+   
+    Parameters
+    ----------
+    ring : Synchrotron
+        Synchrotron object.
+    f0 : float
+        Initial frequency of the sweep in [Hz].
+    f1 : float
+        Final frequency of the sweep in [Hz].
+    t1 : float
+        Time duration of the sweep in [s].
+    level : float
+        Kick level to apply in [V].
+    plane : "x", "y" or "tau"
+        Plane on which to apply the kick.
+    bunch_to_sweep : int, optional
+        Bunch number to which the sweeping is applied.
+        If None, the sweeping is applied for all bunches.
+        Default is None.
+        
+    Methods
+    -------
+    track(bunch_or_beam)
+        Tracking method for the element.
+    plot()
+        Plot the sweep voltage applied.
+    
+    """
+    def __init__(self, ring, f0, f1, t1, level, plane, bunch_to_sweep=None):
+        self.ring = ring
+        self.t = np.arange(0, t1, ring.T0)
+        self.N = len(self.t)
+        self.count = 0
+        self.level = level
+        self.sweep = chirp(self.t, f0, t1, f1)
+        if plane == "x":
+            self.apply = "xp"
+        elif plane == "y":
+            self.apply = "yp"
+        elif plane == "tau":
+            self.apply = "delta"
+        else:
+            raise ValueError("plane should be 'x', 'y' or 'tau'.")
+        self.bunch_to_sweep = bunch_to_sweep
+            
+    def track(self, bunch_or_beam):
+        """
+        Tracking method for this element.
+        
+        Parameters
+        ----------
+        bunch_or_beam : Bunch or Beam
+        """
+        if isinstance(bunch_or_beam, Bunch):
+            bunch = bunch_or_beam
+            self._track_bunch(bunch)
+        elif isinstance(bunch_or_beam, Beam):
+            beam = bunch_or_beam
+            if (beam.mpi_switch == True):
+                if self.bunch_to_sweep is not None:
+                    if beam.mpi.bunch_num == self.bunch_to_sweep:
+                        self._track_bunch(beam[beam.mpi.bunch_num])
+                else:
+                    self._track_bunch(beam[beam.mpi.bunch_num])
+            else:
+                if self.bunch_to_sweep is not None:
+                    self._track_bunch(beam[self.bunch_to_sweep])
+                else:
+                    for bunch in beam.not_empty:
+                        self._track_bunch(bunch, False)
+                    self.count += 1
+                    if self.count >= self.N:
+                        self.count = 0
+        else:
+            raise TypeError("bunch_or_beam should be a Beam or Bunch object.")
+        
+    def _track_bunch(self, bunch, count_step=True):
+        """
+        Tracking method for a bunch for this element.
+        
+        Parameters
+        ----------
+        bunch : Bunch
+        """
+        sweep_val = self.sweep[self.count]
+        bunch[self.apply] += self.level / self.ring.E0 * sweep_val
+        if count_step:
+            self.count += 1
+            if self.count >= self.N:
+                self.count = 0
+
+    def plot(self):
+        """Plot the sweep voltage applied."""
+        fig, ax = plt.subplots()
+        ax.plot(self.t, self.sweep)
+        ax.xlabel("Time [s]")
+        ax.ylabel("Sweep voltage [V]")
+        return fig
\ No newline at end of file
diff --git a/mbtrack2/tracking/monitors/monitors.py b/mbtrack2/tracking/monitors/monitors.py
index f5abcbb..dfa9f54 100644
--- a/mbtrack2/tracking/monitors/monitors.py
+++ b/mbtrack2/tracking/monitors/monitors.py
@@ -1012,6 +1012,10 @@ class BunchSpectrumMonitor(Monitor):
         If True, open the HDF5 file in parallel mode, which is needed to
         allow several cores to write in the same file at the same time.
         If False, open the HDF5 file in standard mode.
+    higher_orders : bool, optional
+        If True, save coherent spectrums for higher order moments (FFT of the 
+        std, skew and kurtosis components).
+        Default is False.
         
     Attributes
     ----------
@@ -1040,12 +1044,15 @@ class BunchSpectrumMonitor(Monitor):
                  dim="all",
                  n_fft=None,
                  file_name=None,
-                 mpi_mode=False):
+                 mpi_mode=False,
+                 higher_orders=False):
 
         if n_fft is None:
             self.n_fft = int(save_every)
         else:
             self.n_fft = int(n_fft)
+            
+        self.higher_orders = higher_orders
 
         self.sample_size = int(sample_size)
         self.store_dict = {"x": 0, "y": 1, "tau": 2}
@@ -1083,21 +1090,23 @@ class BunchSpectrumMonitor(Monitor):
         dict_buffer = {
             "incoherent": (3, self.n_fft // 2 + 1, buffer_size),
             "coherent": (3, self.n_fft // 2 + 1, buffer_size),
-            "coherent_q":(3, self.n_fft//2+1, buffer_size),
-            "coherent_s":(3, self.n_fft//2+1, buffer_size),
-            "coherent_o":(3, self.n_fft//2+1, buffer_size),
             "mean_incoherent": (3, buffer_size),
             "std_incoherent": (3, buffer_size)
         }
         dict_file = {
             "incoherent": (3, self.n_fft // 2 + 1, total_size),
             "coherent": (3, self.n_fft // 2 + 1, total_size),
-            "coherent_q":(3, self.n_fft//2+1, total_size),
-            "coherent_s":(3, self.n_fft//2+1, total_size),
-            "coherent_o":(3, self.n_fft//2+1, total_size),
             "mean_incoherent": (3, total_size),
             "std_incoherent": (3, total_size)
         }
+        
+        if self.higher_orders:
+            dict_buffer["coherent_q"] = (3, self.n_fft//2+1, buffer_size)
+            dict_buffer["coherent_s"] = (3, self.n_fft//2+1, buffer_size)
+            dict_buffer["coherent_o"] = (3, self.n_fft//2+1, buffer_size)
+            dict_file["coherent_q"] = (3, self.n_fft//2+1, total_size)
+            dict_file["coherent_s"] = (3, self.n_fft//2+1, total_size)
+            dict_file["coherent_o"] = (3, self.n_fft//2+1, total_size)
 
         self.monitor_init(group_name, save_every, buffer_size, total_size,
                           dict_buffer, dict_file, file_name, mpi_mode)
@@ -1110,9 +1119,10 @@ class BunchSpectrumMonitor(Monitor):
         self.positions = np.zeros(
             (self.size_list, self.sample_size, self.save_every + 1))
         self.mean = np.zeros((self.size_list, self.save_every + 1))
-        self.std = np.zeros((self.size_list, self.save_every+1))
-        self.skew = np.zeros((self.size_list, self.save_every+1))
-        self.kurtosis = np.zeros((self.size_list, self.save_every+1))
+        if self.higher_orders:
+            self.std = np.zeros((self.size_list, self.save_every+1))
+            self.skew = np.zeros((self.size_list, self.save_every+1))
+            self.kurtosis = np.zeros((self.size_list, self.save_every+1))
 
         index = np.arange(0, int(mp_number))
         self.index_sample = sorted(random.sample(list(index),
@@ -1120,9 +1130,10 @@ class BunchSpectrumMonitor(Monitor):
 
         self.incoherent = np.zeros((3, self.n_fft // 2 + 1, self.buffer_size))
         self.coherent = np.zeros((3, self.n_fft // 2 + 1, self.buffer_size))
-        self.coherent_q = np.zeros((3, self.n_fft//2+1, self.buffer_size))
-        self.coherent_s = np.zeros((3, self.n_fft//2+1, self.buffer_size))
-        self.coherent_o = np.zeros((3, self.n_fft//2+1, self.buffer_size))
+        if self.higher_orders:
+            self.coherent_q = np.zeros((3, self.n_fft//2+1, self.buffer_size))
+            self.coherent_s = np.zeros((3, self.n_fft//2+1, self.buffer_size))
+            self.coherent_o = np.zeros((3, self.n_fft//2+1, self.buffer_size))
 
         self.file[self.group_name].create_dataset("freq",
                                                   data=self.frequency_samples)
@@ -1184,9 +1195,10 @@ class BunchSpectrumMonitor(Monitor):
                 self.positions[value, :, self.save_count] = np.nan
 
             self.mean[:, self.save_count] = bunch.mean[self.mean_index]
-            self.std[:, self.save_count] = bunch.std[self.mean_index]
-            self.skew[:, self.save_count] = bunch.skew[self.mean_index]
-            self.kurtosis[:, self.save_count] = bunch.kurtosis[self.mean_index]
+            if self.higher_orders:
+                self.std[:, self.save_count] = bunch.std[self.mean_index]
+                self.skew[:, self.save_count] = bunch.skew[self.mean_index]
+                self.kurtosis[:, self.save_count] = bunch.kurtosis[self.mean_index]
 
             self.save_count += 1
 
@@ -1216,10 +1228,10 @@ class BunchSpectrumMonitor(Monitor):
             self.coherent[self.store_dict[key], :,
                           self.buffer_count] = self.get_coherent_spectrum(
                               self.mean[value])
-            self.coherent_q[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.std[value])
-            self.coherent_s[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.skew[value])
-            self.coherent_o[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.kurtosis[value])
-
+            if self.higher_orders:
+                self.coherent_q[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.std[value])
+                self.coherent_s[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.skew[value])
+                self.coherent_o[self.store_dict[key],:,self.buffer_count] = self.get_coherent_spectrum(self.kurtosis[value])
 
         self.buffer_count += 1
 
@@ -1255,15 +1267,16 @@ class BunchSpectrumMonitor(Monitor):
                                          self.buffer_size:(self.write_count +
                                                            1) *
                                          self.buffer_size] = self.coherent
-        self.file[self.group_name]["coherent_q"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.coherent_q
-        self.file[self.group_name]["coherent_s"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.coherent_s
-        self.file[self.group_name]["coherent_o"][:,:, 
-                self.write_count * self.buffer_size:(self.write_count+1) * 
-                self.buffer_size] = self.coherent_o
+        if self.higher_orders: 
+            self.file[self.group_name]["coherent_q"][:,:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.coherent_q
+            self.file[self.group_name]["coherent_s"][:,:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.coherent_s
+            self.file[self.group_name]["coherent_o"][:,:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.coherent_o
 
         self.file.flush()
         self.write_count += 1
diff --git a/mbtrack2/tracking/monitors/plotting.py b/mbtrack2/tracking/monitors/plotting.py
index e407151..a02842b 100644
--- a/mbtrack2/tracking/monitors/plotting.py
+++ b/mbtrack2/tracking/monitors/plotting.py
@@ -780,7 +780,8 @@ def plot_bunchspectrum(filenames,
     bunch_number : int or list of int
         Bunch to plot. This has to be identical to 'bunch_number' parameter in 
         'BunchSpectrumMonitor' object.
-    dataset : {"mean_incoherent", "coherent", "incoherent"}
+    dataset : {"mean_incoherent", "coherent", "incoherent", "coherent_q", 
+               "coherent_s", "coherent_o"}
         HDF5 file's dataset to be plotted. 
         The default is "incoherent".
     dim :  {"x","y","tau"}, optional
@@ -834,7 +835,7 @@ def plot_bunchspectrum(filenames,
             ax.errorbar(time, y_var, y_err)
             xlabel = "Turn number"
             ylabel = "Mean incoherent frequency [Hz]"
-        elif dataset == "incoherent" or dataset == "coherent":
+        else:
 
             if turns is None:
                 turn_index = np.where(time == time)[0]
@@ -862,10 +863,6 @@ def plot_bunchspectrum(filenames,
                 ax.set_yscale('log')
 
             ylabel = "FFT amplitude [a.u.]"
-            if dataset == "incoherent":
-                ax.set_title("Incoherent spectrum")
-            elif dataset == "coherent":
-                ax.set_title("Coherent spectrum")
 
         ax.set_xlabel(xlabel)
         ax.set_ylabel(ylabel)
@@ -897,7 +894,7 @@ def streak_bunchspectrum(filename,
     bunch_number : int
         Bunch to plot. This has to be identical to 'bunch_number' parameter in 
         'BunchSpectrumMonitor' object.
-    dataset : {"coherent", "incoherent"}
+    dataset : {coherent", "incoherent", "coherent_q", "coherent_s", "coherent_o"}
         HDF5 file's dataset to be plotted. 
         The default is "incoherent".
     dim :  {"x","y","tau"}, optional
@@ -921,13 +918,11 @@ def streak_bunchspectrum(filename,
         If not None, should be array like in the form [ymin, ymax] where ymin 
         and ymax are the minimum and maxmimum values used in the y axis.
         
-
     Returns
     -------
     fig : Figure
 
     """
-
     file = hp.File(filename, "r")
     group = file["BunchSpectrum_{0}".format(bunch_number)]
 
@@ -984,11 +979,6 @@ def streak_bunchspectrum(filename,
         ylabel = ""
 
     fig, ax = plt.subplots()
-    if dataset == "incoherent":
-        ax.set_title("Incoherent spectrum")
-    elif dataset == "coherent":
-        ax.set_title("Coherent spectrum")
-
     cmap = mpl.cm.inferno  # sequential
     c = ax.imshow(data.T,
                   cmap=cmap,
diff --git a/mbtrack2/tracking/particles.py b/mbtrack2/tracking/particles.py
index bff2263..e9b673b 100644
--- a/mbtrack2/tracking/particles.py
+++ b/mbtrack2/tracking/particles.py
@@ -97,7 +97,13 @@ class Bunch:
         Standard deviation of the position of alive particles for each 
         coordinates.
     emit : array of shape (3,)
-        Bunch emittance for each plane [1]. !!! -> Correct for long ?
+        Bunch emittance for each plane [1].
+    skew : array of shape (6,)
+        Skew (3rd moment) of the position of alive particles for each 
+        coordinates.
+    kurtosis : array of shape (6,)
+        Kurtosis (4th moment) of the position of alive particles for each 
+        coordinates.
         
     Methods
     -------
@@ -248,7 +254,7 @@ class Bunch:
     @property
     def skew(self):
         """
-        Return the standard deviation of the position of alive 
+        Return the skew (3rd moment) of the position of alive 
         particles for each coordinates.
         """
         skew = [[moment(self[name],3)] for name in self]
@@ -257,7 +263,7 @@ class Bunch:
     @property
     def kurtosis(self):
         """
-        Return the standard deviation of the position of alive 
+        Return the kurtosis (4th moment) of the position of alive 
         particles for each coordinates.
         """
         kurtosis = [[moment(self[name],4)] for name in self]
diff --git a/mbtrack2/tracking/sweep.py b/mbtrack2/tracking/sweep.py
deleted file mode 100644
index 61d0c3d..0000000
--- a/mbtrack2/tracking/sweep.py
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-Created on Fri Mar 29 15:12:28 2024
-
-@author: gamelina
-"""
-import numpy as np
-import matplotlib.pyplot as plt
-from mbtrack2.tracking.element import Element
-from scipy.signal import chirp
-
-class Sweep(Element):
-    def __init__(self, ring, f0, f1, t1, level):
-        self.ring = ring
-        self.t = np.arange(0, t1, ring.T0)
-        self.N = len(self.t)
-        self.count = 0
-        self.level = level
-        self.sweep = chirp(self.t, f0, t1, f1)
-
-    @Element.parallel
-    def track(self, bunch):
-        """
-        Tracking method for the element.
-        No bunch to bunch interaction, so written for Bunch objects and
-        @Element.parallel is used to handle Beam objects.
-        
-        Parameters
-        ----------
-        bunch : Bunch or Beam object
-        """
-        sweep_val = self.sweep[self.count]
-        bunch["delta"] += self.level / self.ring.E0 * sweep_val
-        self.count += 1
-        if self.count >= self.N:
-            self.count = 0
-
-    def plot(self):
-        fig, ax = plt.subplots()
-        ax.plot(self.t, self.sweep)
-        return fig
\ No newline at end of file
-- 
GitLab