From f3a40eb111a79453fa426dca819e19160bad03f8 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <alexis.gamelin@synchrotron-soleil.fr>
Date: Wed, 2 Mar 2022 18:13:19 +0100
Subject: [PATCH] Plotting function improvements

Fix error on plot_beamdata.
Add option to specify ylim to streak_bunchspectrum and to streak_beamspectrum.
---
 tracking/monitors/plotting.py | 41 ++++++++++++++++++++++++++---------
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index e481411..dfbe1bb 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -100,7 +100,7 @@ def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean",
             if turn is None:
                 idx = -1
             else:
-                idx = np.where(time == int(turn))
+                idx = np.where(time == int(turn))[0]
                 if (idx.size == 0):
                     raise ValueError("Turn is not valid.")
             
@@ -748,7 +748,7 @@ def plot_bunchspectrum(filenames, bunch_number, dataset="incoherent", dim="tau",
 
 def streak_bunchspectrum(filename, bunch_number, dataset="incoherent", 
                          dim="tau", fs=None, log_scale=True, fmin=None, 
-                         fmax=None, turns=None, norm=False):
+                         fmax=None, turns=None, norm=False, ylim=None):
     """
     Plot 3D data recorded by the BunchSpectrumMonitor.
 
@@ -779,6 +779,10 @@ def streak_bunchspectrum(filename, bunch_number, dataset="incoherent",
     norm : bool, optional
         If True, normalise the data of each spectrum by its geometric mean.
         The default is False.
+    ylim : array, optional
+        If not None, should be array like in the form [ymin, ymax] where ymin 
+        and ymax are the minimum and maxmimum values used in the y axis.
+        
 
     Returns
     -------
@@ -795,8 +799,12 @@ def streak_bunchspectrum(filename, bunch_number, dataset="incoherent",
     
     if turns is None:
         turn_index = np.where(time == time)[0]
-        tmin = time.min()
-        tmax = time.max()
+        if ylim is None:
+            tmin = time.min()
+            tmax = time.max()
+        else:
+            tmin = ylim[0]
+            tmax = ylim[1]
     else:
         tmin = turns.min()
         tmax = turns.max()
@@ -831,8 +839,11 @@ def streak_bunchspectrum(filename, bunch_number, dataset="incoherent",
     
     if norm is True:
         data = data/gmean(data)
-        
-    ylabel = "Turn number"
+    
+    if ylim is None:
+        ylabel = "Turn number"
+    else:
+        ylabel = ""
     
     fig, ax = plt.subplots()
     if dataset == "incoherent":
@@ -937,7 +948,7 @@ def plot_beamspectrum(filenames, dim="tau", turns=None, f0=None,
     return fig
 
 def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None, 
-                         fmax=None, turns=None, norm=False):
+                         fmax=None, turns=None, norm=False, ylim=None):
     """
     Plot 3D data recorded by the BeamSpectrumMonitor.
 
@@ -962,6 +973,9 @@ def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None,
     norm : bool, optional
         If True, normalise the data of each spectrum by its geometric mean.
         The default is False.
+    ylim : array, optional
+        If not None, should be array like in the form [ymin, ymax] where ymin 
+        and ymax are the minimum and maxmimum values used in the y axis.
 
     Returns
     -------
@@ -978,8 +992,12 @@ def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None,
     
     if turns is None:
         turn_index = np.where(time == time)[0]
-        tmin = time.min()
-        tmax = time.max()
+        if ylim is None:
+            tmin = time.min()
+            tmax = time.max()
+        else:
+            tmin = ylim[0]
+            tmax = ylim[1]
     else:
         tmin = turns.min()
         tmax = turns.max()
@@ -1015,7 +1033,10 @@ def streak_beamspectrum(filename, dim="tau", f0=None, log_scale=True, fmin=None,
     if norm is True:
         data = data/gmean(data)
         
-    ylabel = "Turn number"
+    if ylim is None:
+        ylabel = "Turn number"
+    else:
+        ylabel = ""
     
     fig, ax = plt.subplots()
     ax.set_title("Beam coherent spectrum")   
-- 
GitLab