From b4350aa3e719be59a47befbfb59ef50623ede702 Mon Sep 17 00:00:00 2001
From: Gamelin Alexis <gamelin@synchrotron-soleil.fr>
Date: Fri, 24 Apr 2020 17:46:22 +0200
Subject: [PATCH] Rebase on master and rework plot_phasespacedata

Rebase test_plot on master
plot_phasespacedata -> only "particles" dataset is plotted
plot_phasespacedata -> mandatory arguments
---
 tracking/particles.py |  25 +++++----
 tracking/plotting.py  | 117 ++++++++++++++++++------------------------
 2 files changed, 64 insertions(+), 78 deletions(-)

diff --git a/tracking/particles.py b/tracking/particles.py
index fc52fa8..cb4c046 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -88,6 +88,8 @@ class Bunch:
     -------
     init_gaussian(cov=None, mean=None, **kwargs)
         Initialize bunch particles with 6D gaussian phase space.
+    plot_phasespace(x_var="tau", y_var="delta", plot_type="j")
+        Plot phase space.
         
     References
     ----------
@@ -321,18 +323,18 @@ class Bunch:
         ax = fig.gca()
         ax.plot(bins.mid, profile)
         
-    def getplot(self,x,y,p_type):
+    def plot_phasespace(self, x_var="tau", y_var="delta", plot_type="j"):
         """
-        Plot longitudinal phase space.
+        Plot phase space.
         
         Parameters
         ----------
         x_var : str 
-            name from Bunch object to plot on horizontal axis.
+            Dimension to plot on horizontal axis.
         y_var : str 
-            name from Bunch object to plot on ver. axis.
+            Dimension to plot on vertical axis.
         plot_type : str {"j" , "sc"} 
-            type of the plot. The defualt value is "j" for a joint plot.
+            Type of the plot. The defualt value is "j" for a joint plot.
             Can be modified to "sc" for a scatter plot.
         """
         
@@ -352,7 +354,8 @@ class Bunch:
             plt.xlabel(label_dict[x_var])
             plt.ylabel(label_dict[y_var])
             
-        else: raise ValueError("Plot type not recognised")
+        else: 
+            raise ValueError("Plot type not recognised.")
         
 class Beam:
     """
@@ -390,7 +393,6 @@ class Beam:
         Status of MPI parallelisation, should not be changed directly but with
         mpi_init() and mpi_close()
         
-        
     Methods
     ------
     init_beam(filling_pattern, current_per_bunch=1e-3, mp_per_bunch=1e3)
@@ -404,6 +406,8 @@ class Beam:
         all processors. Rather slow
     mpi_close()
         Call mpi_gather and switch off MPI parallelisation
+    plot(var, option=None)
+        Plot variables with respect to bunch number.
     """
     
     def __init__(self, ring, bunch_list=None):
@@ -610,13 +614,13 @@ class Beam:
         self.mpi_switch = False
         self.mpi = None
         
-    def plot_bunchnumber(self, var, option=None):
+    def plot(self, var, option=None):
         """
-        Plot varviables with respect to bunch number.
+        Plot variables with respect to bunch number.
 
         Parameters
         ----------
-        var : str {"bunch_currebt", "bunch_charge", "bunch_particle", 
+        var : str {"bunch_current", "bunch_charge", "bunch_particle", 
                    "bunch_mean", "bunch_std", "bunch_emit"}
             Variable to be plotted.
         option : str, optional
@@ -626,7 +630,6 @@ class Beam:
                 option = {"x","xp","y","yp","tau","delta"}.
             For "bunch_emit", option = {"x","y","s"}.
             The default is None.
-
         """
         
         var_dict = {"bunch_current":self.bunch_current,
diff --git a/tracking/plotting.py b/tracking/plotting.py
index ac143f5..9ae1449 100644
--- a/tracking/plotting.py
+++ b/tracking/plotting.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 """
-Module for plotting Bunch and Beam objects.
+Module for plotting the data recorded by the monitor module during the 
+tracking.
 
 @author: Watanyu Foosang
 @Date: 10/04/2020
@@ -11,10 +12,9 @@ import matplotlib.pyplot as plt
 import seaborn as sns
 import h5py as hp
 
-
 def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
     """
-    Plot the evolution of the variables from the Beam object.
+    Plot data recorded from a BunchMonitor.
     
     Parameters
     ----------
@@ -82,10 +82,10 @@ def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
     
     file.close()
             
-def plot_phasespacedata(filename, bunch_number, dataset, x_var=None, 
-                        y_var=None, turn=None, only_alive=True):
+def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, 
+                        only_alive=True):
     """
-    Plot data from PhaseSpaceData_0 group of the HDF5 file.
+    Plot data recorded from a PhaseSpaceMonitor.
 
     Parameters
     ----------
@@ -93,78 +93,61 @@ def plot_phasespacedata(filename, bunch_number, dataset, x_var=None,
         Name of the HDF5 file that contains the data.
     bunch_number : int
         Number of the bunch whose data has been saved in the HDF5 file.
-        This has to be identical to 'bunch_number' parameter in 'PhaseSpaceMonitor' object.
-    dataset : str {'alive', 'partcicles'}
-        HDF5 file's dataset to be plotted.
-    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}, optional
-        If dataset is "particles", the variables to be plotted on the horizontal
-        and the vertical axes need to be specified.
-    turn : int, optional
+        This has to be identical to 'bunch_number' parameter in 
+        'PhaseSpaceMonitor' object.
+    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}
+        If dataset is "particles", the variables to be plotted on the 
+        horizontal and the vertical axes need to be specified.
+    turn : int
         Turn at which the data will be extracted.
-    only_alive : bool
-        When only_alive is True, only alive particles are plotted and dead particles will be discarded.
-
+    only_alive : bool, optional
+        When only_alive is True, only alive particles are plotted and dead 
+        particles will be discarded.
     """
     
     file = hp.File(filename, "r")
     
     group = "PhaseSpaceData_{0}".format(bunch_number)
-    
-    if dataset == "alive":
-        alive_at_a_time = []
-        for i in range (len(file[group]["time"])):
-            alive_at_a_time.append(np.sum(file[group][dataset][:,i]))
-            
-        plt.plot(file[group]["time"],alive_at_a_time)
-        plt.xlabel("number of turns")
-        plt.ylabel("number of alive particles")
-        
-    elif dataset == "particles":
-        # Specify the parameter
-        #     x     -> 0 
-        #     xp    -> 1 
-        #     y     -> 2 
-        #     yp    -> 3 
-        #     tau   -> 4 
-        #     delta -> 5  
-             
-        var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
-        scale = [1e3,1e3,1e3,1e3,1e12,1]
-        label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
-                 "$\\delta$"]
-        
-        
-        # find the index of "turn" in time array
-        turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) 
-        
-        if len(turn_index[0]) == 0:
-            raise ValueError("Turn {0} is not found. Enter turn from {1}.".
-                             format(turn, file[group]["time"][:]))
+    dataset = "particles"
 
-        else : pass        
-        
-        path = file[group][dataset]
-
-                        
-        if only_alive is False:
-            # format : sns.jointplot(x_axis, yaxis, kind)
-            x_axis = path[:,var_dict[x_var],turn_index[0][0]]
-            y_axis = path[:,var_dict[y_var],turn_index[0][0]]
+    # Specify the parameter
+    #     x     -> 0 
+    #     xp    -> 1 
+    #     y     -> 2 
+    #     yp    -> 3 
+    #     tau   -> 4 
+    #     delta -> 5  
+             
+    var_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
+    scale = [1e3,1e3,1e3,1e3,1e12,1]
+    label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
+             "$\\delta$"]
+    
+    # find the index of "turn" in time array
+    turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) 
+    
+    if len(turn_index[0]) == 0:
+        raise ValueError("Turn {0} is not found. Enter turn from {1}.".
+                         format(turn, file[group]["time"][:]))     
+    
+    path = file[group][dataset]
 
+    if only_alive is False:
+        # format : sns.jointplot(x_axis, yaxis, kind)
+        x_axis = path[:,var_dict[x_var],turn_index[0][0]]
+        y_axis = path[:,var_dict[y_var],turn_index[0][0]]
 
-        elif only_alive is True:
-            alive_index = np.where(file[group]["alive"][:,turn_index])
+    elif only_alive is True:
+        alive_index = np.where(file[group]["alive"][:,turn_index])
 
-            x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]]
-            y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]]            
-            
-            
-        sns.jointplot(x_axis*scale[var_dict[x_var]], 
-                      y_axis*scale[var_dict[y_var]], kind="kde")
-        
-        plt.xlabel(label[var_dict[x_var]])
-        plt.ylabel(label[var_dict[y_var]])
+        x_axis = path[alive_index[0],var_dict[x_var],turn_index[0][0]]
+        y_axis = path[alive_index[0],var_dict[y_var],turn_index[0][0]]            
         
+    sns.jointplot(x_axis*scale[var_dict[x_var]], 
+                  y_axis*scale[var_dict[y_var]], kind="kde")
+    
+    plt.xlabel(label[var_dict[x_var]])
+    plt.ylabel(label[var_dict[y_var]])
             
     file.close()
 
-- 
GitLab