From 37c0a67f6990e87847e6dd002a0cfed6e376e804 Mon Sep 17 00:00:00 2001
From: Watanyu Foosang <watanyu.f@gmail.com>
Date: Thu, 9 Apr 2020 12:24:45 +0200
Subject: [PATCH] Adding plotting methods in Beam class

Two plotting methods, "plot_bunchdata" and "plot_phasespacedata", have been added to Beam class.
"getplot" method in Bunch class has been modified and renamed to "long_phasespace".
---
 tracking/particles.py | 186 +++++++++++++++++++++++++++++++++++++++---
 1 file changed, 175 insertions(+), 11 deletions(-)

diff --git a/tracking/particles.py b/tracking/particles.py
index 5de2d6c..9c2fdf8 100644
--- a/tracking/particles.py
+++ b/tracking/particles.py
@@ -323,23 +323,30 @@ class Bunch:
         
     def getplot(self,x,y,p_type):
         """
-        Plot the parameters from particles object.
+        Plot longitudinal phase space.
         
         Parameters
         ----------
-        x: str, name from Bunch object to plot on hor. axis.
-        y: str, name from Bunch object to plot on ver. axis.
-        p_type: str, "sc" for a scatter plot or "j" for a joint plot.
+        x_var : str 
+            name from Bunch object to plot on horizontal axis.
+        y_var : str 
+            name from Bunch object to plot on ver. axis.
+        plot_type : str {"j" , "sc"} 
+            type of the plot. The defualt value is "j" for a joint plot.
+            Can be modified to "sc" for a scatter plot.
         """
-   
-        x = self.particles[x]
-        y = self.particles[y]
         
-        if p_type == "sc":
-            plt.scatter(x,y)
+        if plot_type == "sc":
+            plt.scatter(self.particles["tau"]*1e12,
+                        self.particles["delta"])
+            plt.xlabel("$\\tau$ (ps)")
+            plt.ylabel("$\\delta$")
         
-        elif p_type == "j":
-            sns.jointplot(x,y,kind="kde")
+        else:
+            sns.jointplot(self.particles["tau"]*1e12,
+                          self.particles["delta"],kind="kde")
+            plt.xlabel("$\\tau$ (ps)")
+            plt.ylabel("$\\delta$")
         
 class Beam:
     """
@@ -597,7 +604,164 @@ class Beam:
         self.mpi_switch = False
         self.mpi = None
         
+    def long_phasespace(self,bunch_number,x_var="tau",y_var="delta",
+                        plot_type="j"):
+        """
+        Plot longitudinal phase space.
+        
+        Parameters
+        ----------
+        bunch_number : int
+            Specify a bunch among those in beam object to be displayed.
+            The value must not exceed the total length of filling_pattern object.
+        x_var : str 
+            name from Bunch object to plot on horizontal axis.
+        y_var : str 
+            name from Bunch object to plot on ver. axis.
+        plot_type : str {"j" , "sc"} 
+            type of the plot. The defualt value is "j" for a joint plot.
+            Can be modified to "sc" for a scatter plot.
+        """
+        
+        if plot_type == "sc":
+            plt.scatter(self[bunch_number-1]["tau"]*1e12,
+                        self[bunch_number-1]["delta"])
+            plt.xlabel("$\\tau$ (ps)")
+            plt.ylabel("$\\delta$")
+        
+        else:
+            sns.jointplot(self[bunch_number-1]["tau"]*1e12,
+                          self[bunch_number-1]["delta"],kind="kde")
+            plt.xlabel("$\\tau$ (ps)")
+            plt.ylabel("$\\delta$")
+            
+    def plot_bunchdata(self, filename, detaset, x_var="time"):
+        """
+        Plot the evolution of the variables from the Beam object.
+        
+        Parameters
+        ----------
+        filename : str
+            Name of the HDF5 file that contains the data.
+        detaset : str {"current","emit","mean","std"}
+            HDF5 file's dataset to be plotted. 
+        x_var : str, optional
+            The variable to be plotted on horizontal axis. The default is "time".
+
+        """
+        
+        file = hp.File(filename, "r")
+        
+        group = "BunchData_0"  # Data group of the HDF5 file
+        
+        if detaset == "current":
+            y_var = file[group][detaset][:]*1e3
+            label = "current (mA)"
+            
+        elif detaset == "emit":
+            axis = int(input("""Specify the axis by entering the corresponding number :
+                             horizontal axis (x)   -> 0 
+                             vertical axis (y)     -> 1 
+                             longitudinal axis (s) -> 2 
+                             : """))
+                             
+            y_var = file[group][detaset][axis]*1e9
+            
+            if axis == 0: label = "hor. emittance (nm.rad)"
+            elif axis == 1: label = "ver. emittance (nm.rad)"
+            elif axis == 2: label = "long. emittance (nm.rad)"
+            
+            
+        elif detaset == "mean" or "std":
+            param = int(input("""Specify the variable by entering the corresponding number: \
+                               x     -> 0 \
+                               xp    -> 1 \
+                               y     -> 2 \
+                               yp    -> 3 \
+                               tau   -> 4 \
+                               delta -> 5 \
+                               :"""))
+            if param == 0:
+                y_var = file[group][detaset][param]*1e6
+                label = "x (um)"
+            elif param == 1:
+                y_var = file[group][detaset][param]*1e6
+                label = "x' ($\\mu$rad)"
+            elif param == 2:
+                y_var = file[group][detaset][param]*1e6
+                label = "y (um)"
+            elif param == 3:
+                y_var = file[group][detaset][param]*1e6
+                label = "y' ($\\mu$rad)"
+            elif param == 4:
+                y_var = file[group][detaset][param]*1e12
+                label = "$\\tau$ (ps)"
+            else :
+                y_var = file[group][detaset][param]
+                label = "$\\delta$"
+                
+        plt.plot(file[group]["time"][:],y_var)
+        plt.xlabel("number of turns")
+        plt.ylabel(label)
+        
+                
+    def plot_phasespacedata(self, filename, total_size, save_every, dataset):
+        """
+        Plot data from PhaseSpaceData_0 group of the HDF5 file.
+
+        Parameters
+        ----------
+        filename : str
+            Name of the HDF5 file that contains the data.
+        total_size : int
+            Total size of the save regarding to 'PhaseSpaceMonitor' object.
+        save_every : int
+            The frequency of the save regarding to 'PhaseSpaceMonitor' object.
+        dataset : str {'alive', 'partcicles'}
+            HDF5 file's dataset to be plotted.
+
+        """
+        
+        file = hp.File(filename, "r")
+        
+        data_points = int(total_size/save_every)
+        timelist = file["PhaseSpaceData_0"]["time"][0:data_points]
+        
+        if dataset == "alive":
+            alive_at_a_time = []
+            for i in range (data_points):
+                alive_at_a_time.append(np.sum(file["PhaseSpaceData_0"][dataset][:,i]))
+                
+            plt.plot(timelist,alive_at_a_time)
+            plt.xlabel("number of turns")
+            plt.ylabel("number of alive particles")
+            
+        elif dataset == "particles":
+            print("""Specify the parameter to be plotted by entering the corresponding number:
+                  x     -> 0 
+                  xp    -> 1 
+                  y     -> 2 
+                  yp    -> 3 
+                  tau   -> 4 
+                  delta -> 5  """)
+                  
+            x_var = int(input("Horizontal axis: "))
+            y_var = int(input("Vertical axis: "))
+            
+            print("Specify the time at turn =", timelist)
+            turn = int(input(": "))
+            index_in_timelist = np.where(timelist==turn)
+
+            sns.jointplot(file["PhaseSpaceData_0"][dataset][:,x_var,index_in_timelist],
+                          file["PhaseSpaceData_0"][dataset][:,y_var,index_in_timelist],
+                          kind="kde")
+            
+            name_list = ["x (m)","xp (rad)","y (m)","yp (rad)","$\\tau$ (s)","$\\delta$"]
+            plt.xlabel(name_list[x_var])
+            plt.ylabel(name_list[y_var])
         
+   
+       
         
         
     
-- 
GitLab