From c05ac48f1d8fd385bd6bf3bd834a1c82d3b9c2aa Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Thu, 31 Mar 2022 09:49:13 +0200
Subject: [PATCH] [Fix] plot_phasespacedata and plot_profiledata

Fix plot_phasespacedata and plot_profiledata
Change color map for streak plots for plot_profiledata, plot_wakedata, plot_cavitydata
---
 tracking/monitors/plotting.py | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index ab53a90..e41904e 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -366,9 +366,10 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
     
     path = file[group][dataset]
     mp_number = path[:,0,0].size
-    
+
     if only_alive is True:
-        index = np.where(file[group]["alive"][:,turn_index])[0]
+        data = np.array(file[group]["alive"])
+        index = np.where(data[:,turn_index])[0]
     else:
         index = np.arange(mp_number)
         
@@ -461,11 +462,11 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
     n_bin = len(data[:,0])
     
     start_index = np.where(time[:] == start)[0][0]
-    
+
     x_var = np.zeros((num+1,n_bin))
-    turn_index_array = np.zeros((num+1,))
+    turn_index_array = np.zeros((num+1,), dtype=int)
     for i in range(num+1):
-        turn_index = start_index + i * step / save_every 
+        turn_index = int(start_index + i * step / save_every)
         turn_index_array[i] = turn_index
         # construct an array of bin mids
         x_var[i,:] = l_bound[:,turn_index]
@@ -485,7 +486,7 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
         y_var = np.ones((num+1,n_bin)) * turn
         z_var = np.transpose(data[:,turn_index_array])
         fig2, ax2 = plt.subplots()
-        cmap = mpl.cm.cool
+        cmap = mpl.cm.inferno # sequential
         c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
                        extent=[x_var.min()*scale[dimension_dict[dimension]],
                                x_var.max()*scale[dimension_dict[dimension]],
@@ -576,16 +577,17 @@ def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
         dimension_dict = {wake_type:0}
         scale = [1]
         label = ["$\\rho$ (a.u.)"]
-        
+        cmap = mpl.cm.inferno # sequential
     elif dipole == True:
         tau_name = "tau_" + wake_type
         wake_type = "dipole_" + wake_type
         dimension_dict = {wake_type:0}
         scale = [1]
         label = ["Dipole moment (m)"]
-        
+        cmap = mpl.cm.coolwarm # diverging
     else:
         tau_name = "tau_" + wake_type
+        cmap = mpl.cm.coolwarm # diverging
         
     data = np.array(path[wake_type])
         
@@ -617,7 +619,6 @@ def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
         y_var = np.ones((num+1,n_bin)) * turn
         z_var = np.transpose(data[:,turn_index_array]*scale[dimension_dict[wake_type]])
         fig2, ax2 = plt.subplots()
-        cmap = mpl.cm.cool
         c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
                        extent=[x_var.min()*1e12,
                                x_var.max()*1e12,
@@ -1154,12 +1155,13 @@ def plot_cavitydata(filename, cavity_name, phasor="cavity",
         if plot_type == "streak_volt":
             data = np.transpose(np.abs(cavity_data["cavity_phasor_record"][:,:])*1e-6)
             ylabel = labels[ph[phasor]] + " voltage [MV]"
+            cmap = mpl.cm.coolwarm # diverging
         elif plot_type == "streak_phase":
             data = np.transpose(np.angle(cavity_data["cavity_phasor_record"][:,:]))
             ylabel = labels[ph[phasor]] + " phase [rad]"
+            cmap = mpl.cm.coolwarm # diverging
             
         fig, ax = plt.subplots()
-        cmap = mpl.cm.cool
         c = ax.imshow(data, cmap=cmap, origin='lower' , aspect='auto')
         if cm_lim is not None:
             c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
-- 
GitLab