From d73a013a0c164bd36913112c313e4b734548ba1f Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Tue, 29 Sep 2020 11:36:39 +0200
Subject: [PATCH] Add WakePotentialMonitor and plotting

---
 tracking/monitors/__init__.py |   8 ++-
 tracking/monitors/monitors.py | 129 +++++++++++++++++++++++++++++++++-
 tracking/monitors/plotting.py | 108 ++++++++++++++++++++++++++++
 3 files changed, 240 insertions(+), 5 deletions(-)

diff --git a/tracking/monitors/__init__.py b/tracking/monitors/__init__.py
index 028472c..2f94fe6 100644
--- a/tracking/monitors/__init__.py
+++ b/tracking/monitors/__init__.py
@@ -8,6 +8,10 @@ Created on Tue Jan 14 18:11:33 2020
 from mbtrack2.tracking.monitors.monitors import (Monitor, BunchMonitor, 
                                                  PhaseSpaceMonitor,
                                                  BeamMonitor,
-                                                 ProfileMonitor)
+                                                 ProfileMonitor,
+                                                 WakePotentialMonitor)
 from mbtrack2.tracking.monitors.plotting import (plot_bunchdata, 
-                                                 plot_phasespacedata)
\ No newline at end of file
+                                                 plot_phasespacedata,
+                                                 plot_profiledata,
+                                                 plot_beamdata,
+                                                 plot_wakedata)
\ No newline at end of file
diff --git a/tracking/monitors/monitors.py b/tracking/monitors/monitors.py
index c7c6ac8..0382055 100644
--- a/tracking/monitors/monitors.py
+++ b/tracking/monitors/monitors.py
@@ -282,8 +282,7 @@ class BunchMonitor(Monitor):
         object_to_save : Bunch or Beam object
         """        
         self.track_bunch_data(object_to_save)
-
-            
+        
 class PhaseSpaceMonitor(Monitor):
     """
     Monitor a single bunch and save the full phase space.
@@ -626,4 +625,128 @@ class ProfileMonitor(Monitor):
         ----------
         object_to_save : Bunch or Beam object
         """        
-        self.track_bunch_data(object_to_save)
\ No newline at end of file
+        self.track_bunch_data(object_to_save)
+        
+class WakePotentialMonitor(Monitor):
+    """
+    Monitor the wake potential from a single bunch and save attributes (tau, 
+    ...).
+    
+    Parameters
+    ----------
+    bunch_number : int
+        Bunch to monitor.
+    wake_types : str or list of str
+        Wake types to save: "Wlong, "Wxdip", ...
+    n_bin : int
+        Number of bin used for the bunch slicing.
+    file_name : string, optional
+        Name of the HDF5 where the data will be stored. Must be specified
+        the first time a subclass of Monitor is instancied and must be None
+        the following times.
+    save_every : int or float, optional
+        Set the frequency of the save. The data is saved every save_every 
+        call of the montior.
+    buffer_size : int or float, optional
+        Size of the save buffer.
+    total_size : int or float, optional
+        Total size of the save. The following relationships between the 
+        parameters must exist: 
+            total_size % buffer_size == 0
+            number of call to track / save_every == total_size
+    mpi_mode : bool, optional
+        If True, open the HDF5 file in parallel mode, which is needed to
+        allow several cores to write in the same file at the same time.
+        If False, open the HDF5 file in standard mode.
+
+    Methods
+    -------
+    track(wake_potential_to_save)
+        Save data.
+    """
+    
+    def __init__(self, bunch_number, wake_types, n_bin, file_name=None, 
+                 save_every=5, buffer_size=500, total_size=2e4, mpi_mode=True):
+        
+        self.bunch_number = bunch_number
+        group_name = "WakePotentialData_" + str(self.bunch_number)
+        
+        if isinstance(wake_types, str):
+            self.wake_types = [wake_types]
+        else:
+            self.wake_types = wake_types
+            
+        self.n_bin = n_bin
+        
+        dict_buffer = {}
+        dict_file = {}
+        dict_buffer.update({"tau" : (self.n_bin, buffer_size)})
+        dict_file.update({"tau" : (self.n_bin, total_size)})
+        dict_buffer.update({"rho" : (self.n_bin, buffer_size)})
+        dict_file.update({"rho" : (self.n_bin, total_size)})
+        for index, dim in enumerate(self.wake_types):
+            dict_buffer.update({dim : (self.n_bin, buffer_size)})
+            dict_file.update({dim : (self.n_bin, total_size)})
+
+        self.monitor_init(group_name, save_every, buffer_size, total_size,
+                          dict_buffer, dict_file, file_name, mpi_mode)
+        
+        self.dict_buffer = dict_buffer
+        self.dict_file = dict_file
+        
+    def to_buffer(self, wp):
+        """
+        Save data to buffer.
+        
+        Parameters
+        ----------
+        wp : WakePotential object
+        """
+
+        self.time[self.buffer_count] = self.track_count
+        self.tau[:, self.buffer_count] = wp.tau
+        self.rho[:, self.buffer_count] = wp.rho
+        for index, dim in enumerate(self.wake_types):
+            self.__getattribute__(dim)[:, self.buffer_count] = wp.__getattribute__(dim)
+        
+        self.buffer_count += 1
+        
+        if self.buffer_count == self.buffer_size:
+            self.write()
+            self.buffer_count = 0
+            
+    def write(self):
+        """Write data from buffer to the HDF5 file."""
+        
+        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
+                    self.write_count+1)*self.buffer_size] = self.time
+
+        self.file[self.group_name]["tau"][:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.tau
+        
+        self.file[self.group_name]["rho"][:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.rho
+        
+        for dim in self.wake_types:
+            self.file[self.group_name][dim][:, 
+                    self.write_count * self.buffer_size:(self.write_count+1) * 
+                    self.buffer_size] = self.__getattribute__(dim)
+            
+        self.file.flush()
+        self.write_count += 1
+                    
+    def track(self, wake_potential_to_save):
+        """
+        Save data.
+        
+        Parameters
+        ----------
+        object_to_save : WakePotential object
+        """        
+        if self.track_count % self.save_every == 0:
+            self.to_buffer(wake_potential_to_save)
+        self.track_count += 1
+
+            
\ No newline at end of file
diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index 2738321..7f74c63 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -373,4 +373,112 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
         return fig
     elif streak_plot is True:
         return fig2
+    
+def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
+                     stop=None, step=None, profile_plot=False, streak_plot=True):
+    """
+    Plot data recorded by WakePotentialMonitor
+
+    Parameters
+    ----------
+    filename : str
+        Name of the HDF5 file that contains the data.
+    bunch_number : int
+        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
+        'WakePotentialMonitor' object.
+    wake_type : str, optional
+        Wake type to plot: "Wlong", "Wxdip", ... Can also be "rho" to show 
+        bunch profile.
+    start : int, optional
+        First turn to plot. The default is 0.
+    stop : int, optional
+        Last turn to plot. If None, the last turn of the record is selected.
+    step : int, optional
+        Plotting step. This has to be divisible by 'save_every' parameter in
+        'WakePotentialMonitor' object, i.e. step % save_every == 0. If None, 
+        step is equivalent to save_every.
+    profile_plot : bool, optional
+        If Ture, wake potential profile plot is plotted.
+    streak_plot : bool, optional
+        If True, strek plot is plotted.
+
+    Returns
+    -------
+    fig : Figure
+        Figure object with the plot on it.
+
+    """
+    
+    file = hp.File(filename, "r")
+    path = file['WakePotentialData_{0}'.format(bunch_number)]
+    
+    if stop is None:
+        stop = path['time'][-1]
+    elif stop not in path['time']:
+        raise ValueError("stop not found. Choose from {0}"
+                         .format(path['time'][:]))
+ 
+    if start not in path['time']:
+        raise ValueError("start not found. Choose from {0}"
+                         .format(path['time'][:]))
+    
+    save_every = path['time'][1] - path['time'][0]
+    
+    if step is None:
+        step = save_every
+    
+    if step % save_every != 0:
+        raise ValueError("step must be divisible by the recording step "
+                         "which is {0}.".format(save_every))
+    
+    dimension_dict = {"Wlong":0, "Wxdip":1, "Wydip":2, "Wxquad":3, "Wyquad":4, "rho":5}
+    scale = [1e-12, 1e-12, 1e-12, 1e-15, 1e-15, 1]
+    label = ["$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$", "$W_{p,x}^Q (V/pC/mm)$",
+             "$W_{p,y}^Q (V/pC/mm)$", "$\\rho (a.u.)$"]
+    
+    num = int((stop - start)/step)
+    n_bin = len(path[wake_type][:,0])
+    
+    start_index = np.where(path['time'][:] == start)[0][0]
+    
+    x_var = np.zeros((num+1,n_bin))
+    turn_index_array = np.zeros((num+1,))
+    for i in range(num+1):
+        turn_index = start_index + i * step / save_every 
+        turn_index_array[i] = turn_index
+        # construct an array of bin mids
+        x_var[i,:] = path["tau"][:,turn_index]
+        
+    if profile_plot is True:
+        fig, ax = plt.subplots()
+        for i in range(num+1):
+            ax.plot(x_var[i]*1e12,
+                    path[wake_type][:,turn_index_array[i]]*scale[dimension_dict[wake_type]], 
+                    label="turn {0}".format(path['time'][turn_index_array[i]]))
+        ax.set_xlabel("$\\tau$ (ps)")
+        ax.set_ylabel(label[dimension_dict[wake_type]])         
+        ax.legend()
+            
+    if streak_plot is True:
+        turn = np.reshape(path['time'][turn_index_array], (num+1,1))
+        y_var = np.ones((num+1,n_bin)) * turn
+        z_var = np.transpose(path[wake_type][:,turn_index_array])
+        fig2, ax2 = plt.subplots()
+        cmap = mpl.cm.cool
+        c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
+                       extent=[x_var.min()*1e12,
+                               x_var.max()*1e12,
+                               y_var.min(),y_var.max()])
+        ax2.set_xlabel("$\\tau$ (ps)")
+        ax2.set_ylabel("Number of turns")
+        cbar = fig2.colorbar(c, ax=ax2)
+        cbar.set_label(label[dimension_dict[wake_type]]) 
+
+    file.close()
+    if profile_plot is True and streak_plot is True:
+        return fig, fig2
+    elif profile_plot is True:
+        return fig
+    elif streak_plot is True:
+        return fig2
     
\ No newline at end of file
-- 
GitLab