# -*- coding: utf-8 -*-
"""
Module for plotting the data recorded by the monitor module during
tracking.

@author: Watanyu Foosang
@Date: 10/04/2020
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import h5py as hp

def plot_beamdata(filename, dataset, option=None, stat_var=None, x_var="time"):
    """
    Plot data recorded from a BeamMonitor.

    For every recorded turn, the selected quantity is reduced over the first
    axis of the stored data (presumably the bunches of the beam — confirm
    against BeamMonitor). The result is plotted against the "time" dataset
    of the "Beam" group.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    dataset : {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted.
    option : str, optional
        If dataset is "emit", "mean", or "std", the variable name to be plotted
        needs to be specified :
            for "emit", option = {"x","y","s"}
            for "mean" and "std", option = {"x","xp","y","yp","tau","delta"}
    stat_var : {"mean", "std"}, optional
        Statistic taken over the first axis of the data for each turn.
        Required unless dataset = "current", in which case the values are
        summed instead. The default is None.
    x_var : str, optional
        Currently unused: the horizontal axis is always the "time" dataset.
        Kept for backward compatibility. The default is "time".

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset is not one of the supported names, or if stat_var is not
        "mean" or "std" while dataset is not "current".
    KeyError
        If option is not a valid variable name for the selected dataset.
    """
    # Validate up front so bad arguments fail before any file access and
    # never leave the figure / label variables unbound.
    if dataset not in ("current", "emit", "mean", "std"):
        raise ValueError("dataset must be one of 'current', 'emit', 'mean' "
                         "or 'std', got {0!r}.".format(dataset))
    if dataset != "current" and stat_var not in ("mean", "std"):
        raise ValueError("stat_var must be 'mean' or 'std' when dataset is "
                         "{0!r}, got {1!r}.".format(dataset, stat_var))

    # Per-turn statistic applied over the first axis of the data.
    reducer = np.std if stat_var == "std" else np.mean

    # The context manager guarantees the file is closed even if a lookup
    # (e.g. an invalid option) raises.
    with hp.File(filename, "r") as file:
        path = file["Beam"]
        time = path["time"][:]
        n_turn = len(time)

        if dataset == "current":
            # Total current of each turn, converted from A to mA.
            y_data = np.sum(path["current"][:, :n_turn], axis=0) * 1e3
            ylabel = "total current (mA)"

        elif dataset == "emit":
            option_dict = {"x": 0, "y": 1, "s": 2}
            axis = option_dict[option]
            scale = [1e12, 1e12, 1e15]
            label = ["$\\epsilon_{x}$ (pm.rad)",
                     "$\\epsilon_{y}$ (pm.rad)",
                     "$\\epsilon_{s}$ (fm.rad)"]
            # One slab read instead of one HDF5 read per turn.
            data = path["emit"][axis, :, :n_turn]
            y_data = reducer(data, axis=0) * scale[axis]
            ylabel = stat_var + " " + label[axis]

        else:  # dataset is "mean" or "std"
            option_dict = {"x": 0, "xp": 1, "y": 2, "yp": 3,
                           "tau": 4, "delta": 5}
            axis = option_dict[option]
            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
            label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                     "$\\tau$ (ps)", "$\\delta$"]
            label_sup = {"mean": "", "std": "std of "}
            # One slab read instead of one HDF5 read per turn.
            data = path[dataset][axis, :, :n_turn] * scale[axis]
            y_data = reducer(data, axis=0)
            ylabel = label_sup[stat_var] + dataset + " " + label[axis]

    fig, ax = plt.subplots()
    ax.plot(time, y_data)
    ax.set_xlabel("Number of turns")
    ax.set_ylabel(ylabel)
    return fig
              
def plot_bunchdata(filename, bunch_number, dataset, option=None, x_var="time"):
    """
    Plot data recorded from a BunchMonitor.

    The selected quantity is read from the "BunchData_<bunch_number>" group
    and plotted against that group's "time" dataset.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        The bunch number whose data has been saved in the HDF5 file.
        This has to be identical to 'bunch_number' parameter in 'BunchMonitor' object.
    dataset : str {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted.
    option : str, optional
        If dataset is "emit", "mean", or "std", the variable name to be plotted
        needs to be specified :
            for "emit", option = {"x","y","s"}
            for "mean" and "std", option = {"x","xp","y","yp","tau","delta"}
    x_var : str, optional
        Currently unused: the horizontal axis is always the "time" dataset.
        Kept for backward compatibility. The default is "time".

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset is not one of the supported names.
    KeyError
        If option is not a valid variable name for the selected dataset, or
        if the bunch group is not present in the file.
    """
    # Previously `dataset == "mean" or "std"` was always true, so any
    # unknown dataset silently went down the mean/std path.
    if dataset not in ("current", "emit", "mean", "std"):
        raise ValueError("dataset must be one of 'current', 'emit', 'mean' "
                         "or 'std', got {0!r}.".format(dataset))

    group = "BunchData_{0}".format(bunch_number)  # Data group of the HDF5 file

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        data_group = file[group]
        time = data_group["time"][:]

        if dataset == "current":
            # Convert from A to mA.
            y_var = data_group[dataset][:] * 1e3
            label = "current (mA)"

        elif dataset == "emit":
            option_dict = {"x": 0, "y": 1, "s": 2}
            label_dict = {"x": "hor. emittance (nm.rad)",
                          "y": "ver. emittance (nm.rad)",
                          "s": "long. emittance (nm.rad)"}
            y_var = data_group[dataset][option_dict[option]] * 1e9
            label = label_dict[option]

        else:  # dataset is "mean" or "std"
            option_dict = {"x": 0, "xp": 1, "y": 2, "yp": 3,
                           "tau": 4, "delta": 5}
            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
            label_list = ["x (um)", "x' ($\\mu$rad)", "y (um)",
                          "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
            axis = option_dict[option]
            y_var = data_group[dataset][axis] * scale[axis]
            label = label_list[axis]

    fig, ax = plt.subplots()
    ax.plot(time, y_var)
    ax.set_xlabel("number of turns")
    ax.set_ylabel(label)
    return fig
            
def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn, 
                        only_alive=True):
    """
    Plot data recorded from a PhaseSpaceMonitor.

    A kernel-density joint plot of two phase-space coordinates is drawn for
    the particles of one bunch at one recorded turn.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Number of the bunch whose data has been saved in the HDF5 file.
        This has to be identical to 'bunch_number' parameter in 
        'PhaseSpaceMonitor' object.
    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}
        Variables to be plotted on the horizontal and the vertical axes.
    turn : int
        Turn at which the data will be extracted. It must be one of the
        values stored in the group's "time" dataset.
    only_alive : bool, optional
        When only_alive is truthy, only alive particles are plotted and dead
        particles are discarded. The default is True.

    Returns
    -------
    fig : seaborn.JointGrid
        Grid object returned by seaborn.jointplot, with the plot on it.

    Raises
    ------
    KeyError
        If x_var or y_var is not a valid variable name.
    ValueError
        If turn is not one of the recorded turns.
    """
    option_dict = {"x": 0, "xp": 1, "y": 2, "yp": 3, "tau": 4, "delta": 5}
    scale = [1e3, 1e3, 1e3, 1e3, 1e12, 1]
    label = ["x (mm)", "x' (mrad)", "y (mm)", "y' (mrad)", "$\\tau$ (ps)",
             "$\\delta$"]

    # Resolve the coordinate names first so invalid names fail before any
    # file access.
    x_index = option_dict[x_var]
    y_index = option_dict[y_var]

    group = "PhaseSpaceData_{0}".format(bunch_number)
    dataset = "particles"

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        data_group = file[group]
        time = data_group["time"][:]

        # Find the index of "turn" in this bunch's own time array (this was
        # previously hard-coded to the "PhaseSpaceData_0" group).
        turn_index = np.flatnonzero(time == turn)
        if turn_index.size == 0:
            raise ValueError("Turn {0} is not found. Enter turn from {1}.".
                             format(turn, time))
        turn_pos = int(turn_index[0])

        path = data_group[dataset]
        x_axis = path[:, x_index, turn_pos]
        y_axis = path[:, y_index, turn_pos]

        if only_alive:
            # Boolean mask over the particle axis; non-zero entries are
            # treated as alive.
            alive = data_group["alive"][:, turn_pos].astype(bool)
            x_axis = x_axis[alive]
            y_axis = y_axis[alive]

    # Keyword x/y: positional data arguments are rejected by newer seaborn.
    grid = sns.jointplot(x=x_axis * scale[x_index],
                         y=y_axis * scale[y_index], kind="kde")

    # Label the joint axes explicitly; plt.xlabel/ylabel would act on
    # whichever axes is current, which is not necessarily the joint axes.
    grid.set_axis_labels(label[x_index], label[y_index])

    return grid