# -*- coding: utf-8 -*-
"""
Module for plotting the data recorded by the monitor module during the 
tracking.

@author: Watanyu Foosang
@Date: 10/04/2020
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import h5py as hp
import random

def plot_beamdata(filename, dataset, dimension=None, stat_var=None, x_var="time"):
    """
    Plot data recorded by BeamMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    dataset : {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted.
    dimension : str, optional
        The dimension of the dataset to plot. Use "None" for "current",
        otherwise use the following : 
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
    stat_var : {"mean", "std"}, optional
        Statistic to plot: the NaN-aware mean or standard deviation taken over
        axis 0 of the selected data. Required unless dataset = "current".
    x_var : str, optional
        Variable to be plotted on the horizontal axis. Only "time" (number of
        turns) is currently supported; the argument is accepted for interface
        compatibility and otherwise ignored. The default is "time".

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset is unknown, or if stat_var is not "mean"/"std" while a
        statistic is needed (any dataset other than "current").
    KeyError
        If dimension is not valid for the chosen dataset.

    """
    stat_functions = {"mean": np.nanmean, "std": np.nanstd}

    # Validate before opening the file so that mistakes fail fast with a clear
    # message (previously they surfaced as UnboundLocalError/NameError).
    if dataset not in ("current", "emit", "mean", "std"):
        raise ValueError("dataset must be 'current', 'emit', 'mean' or 'std', "
                         "got {0!r}.".format(dataset))
    if dataset != "current" and stat_var not in stat_functions:
        raise ValueError("stat_var must be 'mean' or 'std' for dataset {0!r}, "
                         "got {1!r}.".format(dataset, stat_var))

    if dataset == "current":
        axis = None
        scale = 1e3
        ylabel = "total current (mA)"
    elif dataset == "emit":
        axis = {"x":0, "y":1, "s":2}[dimension]
        scale = 1
        label = ["$\\epsilon_{x}$ (m.rad)",
                 "$\\epsilon_{y}$ (m.rad)",
                 "$\\epsilon_{s}$ (m.rad)"]
        ylabel = stat_var + " " + label[axis]
    else:  # "mean" or "std"
        axis = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}[dimension]
        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1][axis]
        label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                 "$\\tau$ (ps)", "$\\delta$"]
        prefix = {"mean":"", "std":"std of "}[stat_var]
        ylabel = prefix + dataset + " " + label[axis]

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        path = file["Beam"]
        time = path["time"][:]
        if dataset == "current":
            # NaN-aware sum over axis 0 gives the plotted "total current".
            y_values = np.nansum(path["current"][:], 0)*scale
        else:
            y_values = stat_functions[stat_var](path[dataset][axis, :], 0)*scale

    fig, ax = plt.subplots()
    ax.plot(time, y_values)
    ax.set_xlabel("Number of turns")
    ax.set_ylabel(ylabel)
    return fig
              
def plot_bunchdata(filename, bunch_number, dataset, dimension="x", x_var="time"):
    """
    Plot data recorded by BunchMonitor.
    
    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'BunchMonitor' object.
    dataset : {"current", "emit", "mean", "std", "cs_invariant"}
        HDF5 file's dataset to be plotted.
    dimension : str, optional
        The dimension of the dataset to plot. It is ignored for "current",
        otherwise use the following : 
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"},
            for "cs_invariant", dimension = {"x","y"}.
        The default is "x".
    x_var : {"time", "current"}, optional
        Variable to be plotted on the horizontal axis. The default is "time".
        
    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset or x_var is not one of the supported values.
    KeyError
        If dimension is not valid for the chosen dataset.

    """
    # Validate before opening the file; unknown values previously ended in an
    # UnboundLocalError on 'label' or 'x_axis'.
    if dataset not in ("current", "emit", "mean", "std", "cs_invariant"):
        raise ValueError("dataset must be 'current', 'emit', 'mean', 'std' or "
                         "'cs_invariant', got {0!r}.".format(dataset))
    if x_var not in ("time", "current"):
        raise ValueError("x_var must be 'time' or 'current', got {0!r}."
                         .format(x_var))

    # Resolve the row of the dataset, its scale factor and its axis label.
    if dataset == "current":
        axis_index = None
        scale = 1e3
        label = "current (mA)"
    elif dataset == "emit":
        axis_index = {"x":0, "y":1, "s":2}[dimension]
        scale = 1e9
        label = ["hor. emittance (nm.rad)", "ver. emittance (nm.rad)",
                 "long. emittance (nm.rad)"][axis_index]
    elif dataset in ("mean", "std"):
        axis_index = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}[dimension]
        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1][axis_index]
        if dataset == "mean":
            label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
                          "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
        else:
            label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
                          "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
                          "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
        label = label_list[axis_index]
    else:  # "cs_invariant"
        axis_index = {"x":0, "y":1}[dimension]
        scale = 1
        label = ['$J_x$ (m)', '$J_y$ (m)'][axis_index]

    group = "BunchData_{0}".format(bunch_number)  # Data group of the HDF5 file

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        data = file[group][dataset]
        if axis_index is None:
            y_var = data[:]*scale
        else:
            y_var = data[axis_index]*scale
        x_axis = file[group][x_var][:]

    if x_var == "current":
        x_axis = x_axis * 1e3
        xlabel = "current (mA)"
    else:
        xlabel = "number of turns"

    fig, ax = plt.subplots()
    ax.plot(x_axis, y_var)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(label)
    return fig
            
def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
                        only_alive=True, plot_size=1, plot_kind='kde'):
    """
    Plot data recorded by PhaseSpaceMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'PhaseSpaceMonitor' object.
    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}
        Variables of the "particles" dataset to be plotted on the horizontal
        and the vertical axes.
    turn : int
        Turn at which the data will be extracted.
    only_alive : bool, optional
        When only_alive is True, only alive particles are plotted and dead 
        particles will be discarded.
    plot_size : [0,1], optional
        Number of macro-particles to plot relative to the total number 
        of macro-particles recorded. This option helps reduce processing time
        when the data is big. If fewer particles are available (e.g. because
        dead particles were discarded), all available particles are used.
    plot_kind : {'scatter', 'kde', 'hex', 'reg', 'resid'}, optional
        The plot style. The default is 'kde'. 
        
    Returns
    -------
    fig : seaborn.JointGrid
        Seaborn grid object holding the joint plot.

    Raises
    ------
    ValueError
        If plot_size is outside [0, 1] or turn is not recorded in the file.
    KeyError
        If x_var or y_var is not a valid variable name.
    """
    option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e3,1e3,1e3,1e3,1e12,1]
    label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
             "$\\delta$"]

    # Resolve and validate the cheap arguments before opening the file.
    x_index = option_dict[x_var]
    y_index = option_dict[y_var]
    if not 0 <= plot_size <= 1:
        raise ValueError("plot_size must be in range [0,1].")

    group = "PhaseSpaceData_{0}".format(bunch_number)

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        data = file[group]
        time = data["time"][:]

        # Find the index of "turn" in the time array of *this* bunch's group.
        matches = np.flatnonzero(time == turn)
        if matches.size == 0:
            raise ValueError("Turn {0} is not found. Enter turn from {1}.".
                             format(turn, time))
        turn_index = int(matches[0])

        particles = data["particles"]
        mp_number = particles.shape[0]

        if only_alive:
            # Integer turn index -> 1D mask over macro-particles.
            index = np.flatnonzero(data["alive"][:, turn_index])
        else:
            index = np.arange(mp_number)

        # Read the two full columns once; sub-selection happens in memory,
        # which avoids h5py's restrictions on fancy index lists.
        x_axis = particles[:, x_index, turn_index]
        y_axis = particles[:, y_index, turn_index]

    if plot_size < 1:
        # Never ask for more samples than available particles.
        n_samples = min(int(plot_size*mp_number), index.size)
        samples = sorted(random.sample(list(index), n_samples))
    else:
        samples = index

    fig = sns.jointplot(x=x_axis[samples]*scale[x_index],
                        y=y_axis[samples]*scale[y_index], kind=plot_kind)
    fig.set_axis_labels(label[x_index], label[y_index])
    return fig

def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
                     stop=None, step=None, profile_plot=True, streak_plot=True):
    """
    Plot data recorded by ProfileMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'ProfileMonitor' object.
    dimension : {"x", "xp", "y", "yp", "tau", "delta"}, optional
        Dimension to plot. The default is "tau"
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'ProfileMonitor' object, i.e. step % save_every == 0. If None, step is
        equivalent to save_every.
    profile_plot : bool, optional
        If True, bunch profile plot is plotted.
    streak_plot : bool, optional
        If True, streak plot is plotted.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it. If both plots are requested, a
        tuple (profile figure, streak figure) is returned; if none is
        requested, None is returned.

    Raises
    ------
    ValueError
        If start/stop are not recorded turns, stop < start, step is not
        positive, or step is not a multiple of the recording step.
    KeyError
        If dimension is not a valid dimension name.

    """
    dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
    label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
             "$\\tau$ (ps)", "$\\delta$"]

    # Validate what can be checked without the file.
    dim_index = dimension_dict[dimension]
    if step is not None and step <= 0:
        raise ValueError("step must be a positive integer.")
    if stop is not None and stop < start:
        raise ValueError("stop must be greater than or equal to start.")

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        path = file['ProfileData_{0}'.format(bunch_number)]
        time = path['time'][:]

        if stop is None:
            stop = time[-1]
        elif stop not in time:
            raise ValueError("stop not found. Choose from {0}"
                             .format(time))

        if start not in time:
            raise ValueError("start not found. Choose from {0}"
                             .format(time))

        # With a single record there is no spacing to infer; use 1.
        save_every = time[1] - time[0] if time.size > 1 else 1

        if step is None:
            step = save_every

        if step % save_every != 0:
            raise ValueError("step must be divisible by the recording step "
                             "which is {0}.".format(save_every))

        num = int((stop - start)/step)
        start_index = int(np.flatnonzero(time == start)[0])
        # Integer column indices of the selected records (strictly increasing,
        # as required by h5py fancy indexing).
        turn_index_array = start_index + np.arange(num+1) * int(step // save_every)

        turns = time[turn_index_array]
        counts = path[dimension][:, turn_index_array]   # (n_bin, num+1)
        l_bound = path["{0}_bin".format(dimension)][:, turn_index_array]

    # Bin mids for every selected record, one row per record.
    x_var = np.transpose(0.5*(l_bound[:-1] + l_bound[1:]))
    x_scale = scale[dim_index]

    fig = fig2 = None
    if profile_plot:
        fig, ax = plt.subplots()
        for i, turn in enumerate(turns):
            ax.plot(x_var[i]*x_scale, counts[:, i],
                    label="turn {0}".format(turn))
        ax.set_xlabel(label[dim_index])
        ax.set_ylabel("number of macro-particles")         
        ax.legend()

    if streak_plot:
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(np.transpose(counts), cmap=cmap, origin='lower',
                       aspect='auto',
                       extent=[x_var.min()*x_scale,
                               x_var.max()*x_scale,
                               turns.min(), turns.max()])
        ax2.set_xlabel(label[dim_index])
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label("Number of macro-particles") 

    if profile_plot and streak_plot:
        return fig, fig2
    if profile_plot:
        return fig
    if streak_plot:
        return fig2
    return None
    
def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
                     stop=None, step=None, profile_plot=False, streak_plot=True,
                     bunch_profile=False, dipole=False):
    """
    Plot data recorded by WakePotentialMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'WakePotentialMonitor' object.
    wake_type : {"Wlong", "Wxdip", "Wydip", "Wxquad", "Wyquad"}, optional
        Wake type to plot. The default is "Wlong".
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'WakePotentialMonitor' object, i.e. step % save_every == 0. If None, 
        step is equivalent to save_every.
    profile_plot : bool, optional
        If True, wake potential profile plot is plotted.
    streak_plot : bool, optional
        If True, streak plot is plotted.
    bunch_profile : bool, optional.
        If True, the bunch profile is plotted (takes precedence over dipole).
    dipole : bool, optional
        If True, the dipole moment is plotted.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it. If both plots are requested, a
        tuple (profile figure, streak figure) is returned; if none is
        requested, None is returned.

    Raises
    ------
    ValueError
        If start/stop are not recorded turns, stop < start, step is not
        positive, or step is not a multiple of the recording step.
    KeyError
        If wake_type is unknown while plotting the wake potential itself.

    """
    dimension_dict = {"Wlong":0, "Wxdip":1, "Wydip":2, "Wxquad":3, "Wyquad":4}
    scale = [1e-12, 1e-12, 1e-12, 1e-15, 1e-15]
    label = ["$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$", "$W_{p,x}^Q (V/pC/mm)$",
             "$W_{p,y}^Q (V/pC/mm)$"]

    # The same tau dataset ("tau_" + wake_type) is the horizontal axis whether
    # the wake, the bunch profile or the dipole moment is plotted.
    tau_name = "tau_" + wake_type
    if bunch_profile:
        data_name = "profile_" + wake_type
        scale_factor = 1
        ylabel = "$\\rho$ (a.u.)"
    elif dipole:
        data_name = "dipole_" + wake_type
        scale_factor = 1
        ylabel = "Dipole moment (m)"
    else:
        data_name = wake_type
        wake_index = dimension_dict[wake_type]
        scale_factor = scale[wake_index]
        ylabel = label[wake_index]

    # Validate what can be checked without the file.
    if step is not None and step <= 0:
        raise ValueError("step must be a positive integer.")
    if stop is not None and stop < start:
        raise ValueError("stop must be greater than or equal to start.")

    # The context manager guarantees the file is closed even on errors.
    with hp.File(filename, "r") as file:
        path = file['WakePotentialData_{0}'.format(bunch_number)]
        time = path['time'][:]

        if stop is None:
            stop = time[-1]
        elif stop not in time:
            raise ValueError("stop not found. Choose from {0}"
                             .format(time))

        if start not in time:
            raise ValueError("start not found. Choose from {0}"
                             .format(time))

        # With a single record there is no spacing to infer; use 1.
        save_every = time[1] - time[0] if time.size > 1 else 1

        if step is None:
            step = save_every

        if step % save_every != 0:
            raise ValueError("step must be divisible by the recording step "
                             "which is {0}.".format(save_every))

        num = int((stop - start)/step)
        start_index = int(np.flatnonzero(time == start)[0])
        # Integer column indices of the selected records (strictly increasing,
        # as required by h5py fancy indexing).
        turn_index_array = start_index + np.arange(num+1) * int(step // save_every)

        turns = time[turn_index_array]
        values = path[data_name][:, turn_index_array] * scale_factor
        # tau values of every selected record, one row per record.
        x_var = np.transpose(path[tau_name][:, turn_index_array])

    fig = fig2 = None
    if profile_plot:
        fig, ax = plt.subplots()
        for i, turn in enumerate(turns):
            ax.plot(x_var[i]*1e12, values[:, i],
                    label="turn {0}".format(turn))
        ax.set_xlabel("$\\tau$ (ps)")
        ax.set_ylabel(ylabel)         
        ax.legend()

    if streak_plot:
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(np.transpose(values), cmap=cmap, origin='lower',
                       aspect='auto',
                       extent=[x_var.min()*1e12,
                               x_var.max()*1e12,
                               turns.min(), turns.max()])
        ax2.set_xlabel("$\\tau$ (ps)")
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label(ylabel) 

    if profile_plot and streak_plot:
        return fig, fig2
    if profile_plot:
        return fig
    if streak_plot:
        return fig2
    return None
    
def plot_tunedata(filename, bunch_number):
    """
    Plot data recorded by TuneMonitor.
    
    Two figures are produced: the transverse tunes (rows 0 and 1 of the
    "tune" dataset) and the synchrotron tune (row 2), each with the matching
    row of "tune_spread" as error bars. The first recorded point is omitted
    (NOTE(review): presumably no meaningful tune exists at the first save —
    confirm).
    
    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'TuneMonitor' object.
        
    Returns
    -------
    (fig1, fig2) : tuple of Figure
        Figure with the transverse tunes and figure with the synchrotron tune.

    """
    group = "TuneData_{0}".format(bunch_number)  # Data group of the HDF5 file

    # Read everything needed up front; the context manager closes the file
    # even if a dataset is missing.
    with hp.File(filename, "r") as file:
        time = file[group]["time"][1:]
        tune = file[group]["tune"][:, 1:]
        tune_spread = file[group]["tune_spread"][:, 1:]

    fig1, ax1 = plt.subplots()        
    ax1.errorbar(x=time, y=tune[0], yerr=tune_spread[0], label="x")
    ax1.errorbar(x=time, y=tune[1], yerr=tune_spread[1], label="y")
    ax1.set_xlabel("Turn number")
    ax1.set_ylabel("Transverse tunes")
    ax1.legend()

    fig2, ax2 = plt.subplots()        
    ax2.errorbar(x=time, y=tune[2], yerr=tune_spread[2])
    ax2.set_xlabel("Turn number")
    ax2.set_ylabel("Synchrotron tune")

    return (fig1, fig2)