From e5d98ee00630413ac8f4138b7085901b213f7fd8 Mon Sep 17 00:00:00 2001
From: Watanyu Foosang <watanyu.f@gmail.com>
Date: Mon, 20 Apr 2020 10:48:05 +0200
Subject: [PATCH] Improvements on plot_phasespacedata function

- total_size and save_every input parameters have been removed.
- only_alive option function has been changed
- Some small corrections on scaling of delta, bunch_charge, bunch_current
---
 tracking/particles.py | 35 ++++++++++++++---------
 tracking/plotting.py  | 65 ++++++++++++++++++++-----------------------
 2 files changed, 52 insertions(+), 48 deletions(-)

diff --git a/tracking/particles.py b/tracking/particles.py
index 06aa161..fc52fa8 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -336,8 +336,8 @@ class Bunch:
             Can be modified to "sc" for a scatter plot.
         """
         
-        label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", "yp":"y' (mrad)",
-                      "tau":"$\\tau$ (ps)", "delta":"$\\delta$"}
+        label_dict = {"x":"x (mm)", "xp":"x' (mrad)", "y":"y (mm)", 
+                      "yp":"y' (mrad)","tau":"$\\tau$ (ps)", "delta":"$\\delta$"}
         scale = {"x": 1e3, "xp":1e3, "y":1e3, "yp":1e3, "tau":1e12, "delta":1}
         
         if plot_type == "sc":
@@ -346,11 +346,13 @@ class Bunch:
             plt.xlabel(label_dict[x_var])
             plt.ylabel(label_dict[y_var])
         
-        else:
+        elif plot_type == "j": 
             sns.jointplot(self.particles[x_var]*scale[x_var],
                           self.particles[y_var]*scale[y_var],kind="kde")
             plt.xlabel(label_dict[x_var])
             plt.ylabel(label_dict[y_var])
+            
+        else: raise ValueError("Plot type not recognised")
         
 class Beam:
     """
@@ -614,11 +616,14 @@ class Beam:
 
         Parameters
         ----------
-        var : str {"bunch_currebt", "bunch_charge", "bunch_particle", "bunch_mean", "bunch_std", "bunch_emit"}
+        var : str {"bunch_currebt", "bunch_charge", "bunch_particle", 
+                   "bunch_mean", "bunch_std", "bunch_emit"}
             Variable to be plotted.
         option : str, optional
-            If var is "bunch_mean", "bunch_std", or "bunch_emit, option needs tobe specified. 
-            For "bunch_mean" and "bunch_std", option = {"x","xp","y","yp","tau","delta"}.
+            If var is "bunch_mean", "bunch_std", or "bunch_emit, option needs 
+            to be specified.
+            For "bunch_mean" and "bunch_std", 
+                option = {"x","xp","y","yp","tau","delta"}.
             For "bunch_emit", option = {"x","y","s"}.
             The default is None.
 
@@ -633,12 +638,12 @@ class Beam:
         
         if var == "bunch_mean" or var == "bunch_std":
             value_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1e6]
+            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
             label_mean = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                      "$\\tau$ (ps)", "$\\delta(\\times 10^{-6})$"]
+                      "$\\tau$ (ps)", "$\\delta$"]
             label_std = ["std x (um)", "std x' ($\\mu$rad)", "std y (um)",
                         "std y' ($\\mu$rad)", "std $\\tau$ (ps)",
-                        "std $\\delta(\\times 10^{-6})$"]
+                        "std $\\delta$"]
            
             y_axis = var_dict[var][value_dict[option]]
             
@@ -674,12 +679,16 @@ class Beam:
             plt.xlabel('bunch number')
             plt.ylabel(label_y)
                 
-        elif var == "bunch_current" or var == "bunch_charge" or  var =="bunch_particle":
-            plt.plot(np.arange(len(self.filling_pattern)), var_dict[var]) 
+        elif var=="bunch_current" or var=="bunch_charge" or var=="bunch_particle":
+            scale = {"bunch_current":1e3, "bunch_charge":1e9, 
+                     "bunch_particle":1}
+            
+            plt.plot(np.arange(len(self.filling_pattern)), var_dict[var]*
+                     scale[var]) 
             plt.xlabel('bunch number')
             
-            if var == "bunch_current": label_y = "bunch current (A)"
-            elif var == "bunch_charge": label_y = "bunch chagre (C)"
+            if var == "bunch_current": label_y = "bunch current (mA)"
+            elif var == "bunch_charge": label_y = "bunch chagre (nC)"
             else: label_y = "number of particles"
 
             plt.ylabel(label_y)             
diff --git a/tracking/plotting.py b/tracking/plotting.py
index a194acc..ac143f5 100644
--- a/tracking/plotting.py
+++ b/tracking/plotting.py
@@ -68,9 +68,9 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
         # delta -> 5 
                            
         var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1e6]
+        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
         label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
-                      "$\\tau$ (ps)", "$\\delta(\\times 10^{-6})$"]
+                      "$\\tau$ (ps)", "$\\delta$"]
         
         y_var = file[group][dataset][var_dict[option]]*scale[var_dict[option]]
         label = label_list[var_dict[option]]
@@ -82,8 +82,8 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
     
     file.close()
             
-def plot_phasespacedata(filename, bunch_number, total_size, save_every, dataset, 
-                        x_var=None, y_var=None, turn=None, only_alive=False):
+def plot_phasespacedata(filename, bunch_number, dataset, x_var=None, 
+                        y_var=None, turn=None, only_alive=True):
     """
     Plot data from PhaseSpaceData_0 group of the HDF5 file.
 
@@ -94,10 +94,6 @@ def plot_phasespacedata(filename, bunch_number, total_size, save_every, dataset,
     bunch_number : int
         Number of the bunch whose data has been saved in the HDF5 file.
         This has to be identical to 'bunch_number' parameter in 'PhaseSpaceMonitor' object.
-    total_size : int
-        Total size of the save regarding to 'PhaseSpaceMonitor' object.
-    save_every : int
-        Frequency of the save regarding to 'PhaseSpaceMonitor' object.
     dataset : str {'alive', 'partcicles'}
         HDF5 file's dataset to be plotted.
     x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}, optional
@@ -114,15 +110,12 @@ def plot_phasespacedata(filename, bunch_number, total_size, save_every, dataset,
     
     group = "PhaseSpaceData_{0}".format(bunch_number)
     
-    data_points = int(total_size/save_every)
-    timelist = file[group]["time"][0:data_points] # example output: [0,10,20,30,40]
-    
     if dataset == "alive":
         alive_at_a_time = []
-        for i in range (data_points):
+        for i in range (len(file[group]["time"])):
             alive_at_a_time.append(np.sum(file[group][dataset][:,i]))
             
-        plt.plot(timelist,alive_at_a_time)
+        plt.plot(file[group]["time"],alive_at_a_time)
         plt.xlabel("number of turns")
         plt.ylabel("number of alive particles")
         
@@ -136,36 +129,38 @@ def plot_phasespacedata(filename, bunch_number, total_size, save_every, dataset,
         #     delta -> 5  
              
         var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-        scale = [1e3,1e3,1e3,1e3,1e12,1e3]
-        label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)","$\\delta \\times 10^{-3}$"]
+        scale = [1e3,1e3,1e3,1e3,1e12,1]
+        label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
+                 "$\\delta$"]
+        
         
-        # find the index of "turn" in timelist
-        index_in_timelist = np.where(timelist==turn) 
+        # find the index of "turn" in time array
+        turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) 
+        
+        if len(turn_index[0]) == 0:
+            raise ValueError("Turn {0} is not found. Enter turn from {1}.".
+                             format(turn, file[group]["time"][:]))
+
+        else : pass        
         
         path = file[group][dataset]
+
                         
-        if only_alive == False:
+        if only_alive is False:
             # format : sns.jointplot(x_axis, yaxis, kind)
-            x_axis = path[:,var_dict[x_var],index_in_timelist]
-            y_axis = path[:,var_dict[y_var],index_in_timelist]
+            x_axis = path[:,var_dict[x_var],turn_index[0][0]]
+            y_axis = path[:,var_dict[y_var],turn_index[0][0]]
+
 
+        elif only_alive is True:
+            alive_index = np.where(file[group]["alive"][:,turn_index])
 
-        elif only_alive == True:
-            x_alive = []
-            y_alive = []
-            for i in path[:,var_dict[x_var],index_in_timelist]:
-                if bool(i) == True:
-                    x_alive.append(i)
-                else: pass
-            for i in path[:,var_dict[y_var],index_in_timelist]:
-                if bool(i) == True:
-                    y_alive.append(i)
-                else: pass
-            x_axis = np.array(x_alive)
-            y_axis = np.array(y_alive)
+            x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]]
+            y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]]            
+            
             
-        sns.jointplot(x_axis*scale[var_dict[x_var]], y_axis*scale[var_dict[y_var]],
-                      kind="kde")
+        sns.jointplot(x_axis*scale[var_dict[x_var]], 
+                      y_axis*scale[var_dict[y_var]], kind="kde")
         
         plt.xlabel(label[var_dict[x_var]])
         plt.ylabel(label[var_dict[y_var]])
-- 
GitLab