From 9dd804ed1bc69c17ad9903df63baa519a805779c Mon Sep 17 00:00:00 2001
From: Watanyu Foosang <watanyu.f@gmail.com>
Date: Fri, 19 Jun 2020 14:15:39 +0200
Subject: [PATCH] Options added to plot_bunchdata and plot_phasespacedata

- Option to plot bunch current on x-axis has been included in plot_bunchdata.
- Option to select the plot size for plot_phasespacedata has been added.
---
 tracking/monitors/plotting.py | 92 +++++++++++++++++++----------------
 1 file changed, 49 insertions(+), 43 deletions(-)

diff --git a/tracking/monitors/plotting.py b/tracking/monitors/plotting.py
index d77049a..2738321 100644
--- a/tracking/monitors/plotting.py
+++ b/tracking/monitors/plotting.py
@@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
 import matplotlib as mpl
 import seaborn as sns
 import h5py as hp
+import random
 
 def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
     """
@@ -122,14 +123,14 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
     bunch_number : int
         Bunch to plot. This has to be identical to 'bunch_number' parameter in 
         'BunchMonitor' object.
-    detaset : str {"current","emit","mean","std"}
+    detaset : {"current","emit","mean","std"}
         HDF5 file's dataset to be plotted.
     option : str, optional
         If dataset is "emit", "mean", or "std", the variable name to be plotted
         needs to be specified :
             for "emit", option = {"x","y","s"}
             for "mean" and "std", option = {"x","xp","y","yp","tau","delta"}    
-    x_var : str, optional
+    x_var : {"time", "current"}, optional
         Variable to be plotted on the horizontal axis. The default is "time".
         
     Return
@@ -159,24 +160,37 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
         
     elif dataset == "mean" or "std":                        
         option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
-        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
-        label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                      "$\\tau$ (ps)", "$\\delta$"]
+        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]        
+        axis_index = option_dict[option]
+        
+        y_var = file[group][dataset][axis_index]*scale[axis_index]
+        if dataset == "mean":
+            label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
+                          "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
+        else:
+            label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
+                          "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
+                          "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
         
-        y_var = file[group][dataset][option_dict[option]]\
-                *scale[option_dict[option]]
-        label = label_list[option_dict[option]]
+        label = label_list[axis_index]
+        
+    if x_var == "current":
+        x_axis = file[group]["current"][:] * 1e3
+        xlabel = "current (mA)"
+    elif x_var == "time":
+        x_axis = file[group]["time"][:]
+        xlabel = "number of turns"
         
     fig, ax = plt.subplots()        
-    ax.plot(file[group]["time"][:],y_var)
-    ax.set_xlabel("number of turns")
+    ax.plot(x_axis,y_var)
+    ax.set_xlabel(xlabel)
     ax.set_ylabel(label)
     
     file.close()
     return fig
             
-def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, 
-                        only_alive=True):
+def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
+                        only_alive=True, plot_size=1):
     """
     Plot data recorded by PhaseSpaceMonitor.
 
@@ -195,6 +209,10 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
     only_alive : bool, optional
         When only_alive is True, only alive particles are plotted and dead 
         particles will be discarded.
+    plot_size : [0,1], optional
+        Number of macro-particles to plot relative to the total number 
+        of macro-particles recorded. This option helps reduce processing time
+        when the data is big.
         
     Return
     ------
@@ -220,21 +238,28 @@ def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
                          format(turn, file[group]["time"][:]))     
     
     path = file[group][dataset]
-
-    if only_alive is False:
-        # format : sns.jointplot(x_axis, yaxis, kind)
-        x_axis = path[:,option_dict[x_var],turn_index[0][0]]
-        y_axis = path[:,option_dict[y_var],turn_index[0][0]]
-
-    elif only_alive is True:
-        alive_index = np.where(file[group]["alive"][:,turn_index])
-
-        x_axis = path[alive_index[0],option_dict[x_var],turn_index[0][0]]
-        y_axis = path[alive_index[0],option_dict[y_var],turn_index[0][0]]    
+    mp_number = path[:,0,0].size
+    
+    if only_alive is True:
+        index = np.where(file[group]["alive"][:,turn_index])[0]
+    else:
+        index = np.arange(mp_number)
+        
+    if plot_size == 1:
+        samples = index
+    elif plot_size < 1:
+        samples_meta = random.sample(list(index), int(plot_size*mp_number))
+        samples = sorted(samples_meta)
+    else:
+        raise ValueError("plot_size must be in range [0,1].")
+            
+    # format : sns.jointplot(x_axis, yaxis, kind)
+    x_axis = path[samples,option_dict[x_var],turn_index[0][0]]
+    y_axis = path[samples,option_dict[y_var],turn_index[0][0]]    
         
     fig = sns.jointplot(x_axis*scale[option_dict[x_var]], 
                         y_axis*scale[option_dict[y_var]], kind="kde")
-    
+   
     plt.xlabel(label[option_dict[x_var]])
     plt.ylabel(label[option_dict[y_var]])
             
@@ -348,23 +373,4 @@ def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
         return fig
     elif streak_plot is True:
         return fig2
-    
-    
-    
-        
-
-    
-    
-    
-        
-    
-    
-    
-    
-    
-    
-    
-    
-    
-    
     
\ No newline at end of file
-- 
GitLab