From bd29899689917e67389c4b8232b3de57348d84c6 Mon Sep 17 00:00:00 2001
From: Watanyu Foosang <watanyu.f@gmail.com>
Date: Fri, 12 Mar 2021 16:59:03 +0100
Subject: [PATCH] Improve plot_beamdata

Remove for-loops inside the function.
---
 tracking/monitors/plotting.py | 65 ++++++++++++-----------------------
 1 file changed, 22 insertions(+), 43 deletions(-)

diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index 63ed6bf..f1531e0 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -14,7 +14,7 @@ import seaborn as sns
 import h5py as hp
 import random
 
-def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
+def plot_beamdata(filename, dataset, dimension=None, stat_var=None, x_var="time"):
     """
     Plot data recorded by BeamMonitor.
 
@@ -24,14 +24,14 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
         Name of the HDF5 file that contains the data.
     dataset : {"current","emit","mean","std"}
         HDF5 file's dataset to be plotted.
-    option : str, optional
-        If dataset is "emit", "mean", or "std", the variable name to be plotted
-        needs to be specified :
-            for "emit", option = {"x","y","s"}
-            for "mean" and "std", option = {"x","xp","y","yp","tau","delta"}
+    dimension : str, optional
+        The dimension of the dataset to plot. Use "None" for "current",
+        otherwise use the following : 
+            for "emit", dimension = {"x","y","s"},
+            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
     stat_var : {"mean", "std"}, optional
-        Statistical value of option. Except when dataset = "current", stat_var
-        needs to be specified. The default is None.
+        Statistical value of the dimension. Unless dataset = "current", stat_var
+        needs to be specified.
     x_var : str, optional
         Variable to be plotted on the horizontal axis. The default is "time".
         
@@ -46,64 +46,43 @@ def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
     path = file["Beam"]
     
     if dataset == "current":
-        total_current = []
-        for i in range (len(path["time"])):
-            total_current.append(np.sum(path["current"][:,i])*1e3)
-            
         fig, ax = plt.subplots()
-        ax.plot(path["time"],total_current)
+        ax.plot(path["time"], np.nansum(path["current"][:],0)*1e3)
         ax.set_xlabel("Number of turns")
         ax.set_ylabel("total current (mA)")
         
     elif dataset == "emit":
-        option_dict = {"x":0, "y":1, "s":2} #input option
-        axis = option_dict[option]
-        scale = [1e12, 1e12, 1e15]
-        label = ["$\\epsilon_{x}$ (pm.rad)",
-                 "$\\epsilon_{y}$ (pm.rad)",
-                 "$\\epsilon_{s}$ (fm.rad)"]
+        dimension_dict = {"x":0, "y":1, "s":2} 
+        axis = dimension_dict[dimension]
+        label = ["$\\epsilon_{x}$ (m.rad)",
+                 "$\\epsilon_{y}$ (m.rad)",
+                 "$\\epsilon_{s}$ (m.rad)"]
         
         if stat_var == "mean":
-            mean_emit = []
-            for i in range (len(path["time"])):
-                mean_emit.append(np.mean(path["emit"][axis,:,i])*scale[axis])
-            
             fig, ax = plt.subplots()
-            ax.plot(path["time"],mean_emit)
+            ax.plot(path["time"], np.nanmean(path["emit"][axis,:],0))
             
         elif stat_var == "std":
-            std_emit = []
-            for i in range (len(path["time"])):
-                std_emit.append(np.std(path["emit"][axis,:,i])*scale[axis])
-              
             fig, ax = plt.subplots() 
-            ax.plot(path["time"],std_emit)
+            ax.plot(path["time"], np.nanstd(path["emit"][axis,:],0))
             
         ax.set_xlabel("Number of turns")
         ax.set_ylabel(stat_var+" " + label[axis])
         
     elif dataset == "mean" or dataset == "std":
-        option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-        axis = option_dict[option]
+        dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+        axis = dimension_dict[dimension]
         scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
         label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                       "$\\tau$ (ps)", "$\\delta$"]
         
         fig, ax = plt.subplots()
-        if stat_var == "mean":
-            mean_list = []
-            for i in range (len(path["time"])):
-                mean_list.append(np.mean(path[dataset][axis,:,i]*scale[axis]))
-                
-            ax.plot(path["time"],mean_list)
+        if stat_var == "mean":   
+            ax.plot(path["time"], np.nanmean(path[dataset][axis,:],0)*scale[axis])
             label_sup = {"mean":"", "std":"std of "} # input stat_var
             
-        elif stat_var == "std":
-            std_list = []
-            for i in range (len(path["time"])):
-                std_list.append(np.std(path[dataset][axis,:,i]*scale[axis]))
-            
-            ax.plot(path["time"],std_list)
+        elif stat_var == "std":      
+            ax.plot(path["time"], np.nanstd(path[dataset][axis,:],0)*scale[axis])
             label_sup = {"mean":"", "std":"std of "} #input stat_var
             
         ax.set_xlabel("Number of turns")
-- 
GitLab