# -*- coding: utf-8 -*-
"""
Module for plotting the data recorded by the monitor module during the 
tracking.

@author: Watanyu Foosang
@Date: 10/04/2020
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import h5py as hp
import random
from scipy.fft import rfftfreq

def plot_beamdata(filename, dataset, dimension=None, stat_var=None, x_var="time"):
    """
    Plot data recorded by BeamMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    dataset : {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted.
    dimension : str, optional
        The dimension of the dataset to plot. Use "None" for "current",
        otherwise use the following : 
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
    stat_var : {"mean", "std"}, optional
        Statistical value of the dimension, computed over the first axis of
        the recorded data (presumably the bunch axis) while ignoring NaN.
        Unless dataset = "current", stat_var needs to be specified.
    x_var : str, optional
        Variable to be plotted on the horizontal axis. The default is "time".
        Currently unused: the horizontal axis is always the recorded time.

    Return
    ------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset is unknown, or if stat_var is not "mean" or "std" for a
        dataset other than "current".
    """
    # NaN-aware reductions selected by stat_var.
    stat_funcs = {"mean": np.nanmean, "std": np.nanstd}

    if dataset not in ("current", "emit", "mean", "std"):
        raise ValueError("dataset must be 'current', 'emit', 'mean' or "
                         "'std', got {0!r}.".format(dataset))
    if dataset != "current" and stat_var not in stat_funcs:
        raise ValueError("stat_var must be 'mean' or 'std' for dataset "
                         "{0!r}, got {1!r}.".format(dataset, stat_var))

    with hp.File(filename, "r") as file:
        path = file["Beam"]
        # Read everything into memory while the file is open: HDF5 dataset
        # handles become invalid once the file is closed.
        time = path["time"][:]

        if dataset == "current":
            # Sum over the first axis gives the total current.
            y_data = np.nansum(path["current"][:], 0)*1e3
            ylabel = "total current (mA)"

        elif dataset == "emit":
            axis = {"x":0, "y":1, "s":2}[dimension]
            label = ["$\\epsilon_{x}$ (m.rad)",
                     "$\\epsilon_{y}$ (m.rad)",
                     "$\\epsilon_{s}$ (m.rad)"]
            y_data = stat_funcs[stat_var](path["emit"][axis, :], 0)
            ylabel = stat_var + " " + label[axis]

        else:  # "mean" or "std"
            axis = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}[dimension]
            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
            label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
                     "$\\tau$ (ps)", "$\\delta$"]
            label_sup = {"mean":"", "std":"std of "}
            y_data = stat_funcs[stat_var](path[dataset][axis, :], 0)*scale[axis]
            ylabel = label_sup[stat_var] + dataset + " " + label[axis]

    fig, ax = plt.subplots()
    ax.plot(time, y_data)
    ax.set_xlabel("Number of turns")
    ax.set_ylabel(ylabel)
    return fig
              
def plot_bunchdata(filename, bunch_number, dataset, dimension="x", x_var="time"):
    """
    Plot data recorded by BunchMonitor.
    
    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'BunchMonitor' object.
    dataset : {"current", "emit", "mean", "std", "cs_invariant"}
        HDF5 file's dataset to be plotted.
    dimension : str, optional
        The dimension of the dataset to plot. Ignored for "current",
        otherwise use the following : 
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"},
            for "cs_invariant", dimension = {"x","y"}.
    x_var : {"time", "current"}, optional
        Variable to be plotted on the horizontal axis. The default is "time".
        
    Return
    ------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    ValueError
        If dataset or x_var is not one of the supported values.
    """
    group = "BunchData_{0}".format(bunch_number)  # Data group of the HDF5 file

    with hp.File(filename, "r") as file:
        data = file[group]

        if dataset == "current":
            y_var = data[dataset][:]*1e3
            label = "current (mA)"

        elif dataset == "emit":
            dimension_dict = {"x":0, "y":1, "s":2}
            label_dict = {"x": "hor. emittance (nm.rad)",
                          "y": "ver. emittance (nm.rad)",
                          "s": "long. emittance (nm.rad)"}
            y_var = data[dataset][dimension_dict[dimension]]*1e9
            label = label_dict[dimension]

        elif dataset in ("mean", "std"):
            dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
            axis_index = dimension_dict[dimension]

            y_var = data[dataset][axis_index]*scale[axis_index]
            if dataset == "mean":
                label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
                              "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
            else:
                label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
                              "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
                              "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
            label = label_list[axis_index]

        elif dataset == "cs_invariant":
            dimension_dict = {"x":0, "y":1}
            axis_index = dimension_dict[dimension]
            y_var = data[dataset][axis_index]
            label_list = ['$J_x$ (m)', '$J_y$ (m)']
            label = label_list[axis_index]

        else:
            raise ValueError("dataset must be 'current', 'emit', 'mean', 'std' "
                             "or 'cs_invariant', got {0!r}.".format(dataset))

        if x_var == "current":
            x_axis = data["current"][:] * 1e3
            xlabel = "current (mA)"
        elif x_var == "time":
            x_axis = data["time"][:]
            xlabel = "number of turns"
        else:
            raise ValueError("x_var must be 'time' or 'current', got {0!r}."
                             .format(x_var))

    # All data are NumPy arrays at this point, so the file can be closed.
    fig, ax = plt.subplots()
    ax.plot(x_axis, y_var)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(label)
    return fig
            
def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
                        only_alive=True, plot_size=1, plot_kind='kde'):
    """
    Plot data recorded by PhaseSpaceMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'PhaseSpaceMonitor' object.
    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}
        Variables to be plotted on the horizontal and the vertical axes.
    turn : int
        Turn at which the data will be extracted. It must be one of the
        recorded turns.
    only_alive : bool, optional
        When only_alive is True, only alive particles are plotted and dead 
        particles will be discarded.
    plot_size : float in [0,1], optional
        Number of macro-particles to plot relative to the total number 
        of macro-particles recorded. The sample is capped at the number of
        candidate macro-particles (the alive ones when only_alive is True).
        This option helps reduce processing time when the data is big.
    plot_kind : {'scatter', 'kde', 'hex', 'reg', 'resid'}, optional
        The plot style. The default is 'kde'. 
        
    Return
    ------
    fig : seaborn.JointGrid
        Grid object returned by seaborn's jointplot, with the plot on it.

    Raises
    ------
    ValueError
        If plot_size is outside [0,1] or turn is not a recorded turn.
    """
    option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e3,1e3,1e3,1e3,1e12,1]
    label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
             "$\\delta$"]
    x_col = option_dict[x_var]
    y_col = option_dict[y_var]

    if not 0 <= plot_size <= 1:
        raise ValueError("plot_size must be in range [0,1].")

    group = "PhaseSpaceData_{0}".format(bunch_number)
    dataset = "particles"

    with hp.File(filename, "r") as file:
        time = file[group]["time"][:]

        # find the index of "turn" in the time array of the requested bunch
        turn_match = np.where(time == turn)[0]
        if len(turn_match) == 0:
            raise ValueError("Turn {0} is not found. Enter turn from {1}."
                             .format(turn, time))
        turn_index = int(turn_match[0])

        path = file[group][dataset]
        mp_number = path.shape[0]

        if only_alive is True:
            index = np.where(file[group]["alive"][:, turn_index])[0]
        else:
            index = np.arange(mp_number)

        if plot_size == 1:
            samples = index
        else:
            # Cap the sample size: with only_alive there may be fewer than
            # int(plot_size*mp_number) candidates, and random.sample would raise.
            n_samples = min(int(plot_size*mp_number), len(index))
            # h5py fancy indexing requires increasing indices, hence sorted().
            samples = sorted(random.sample(list(index), n_samples))

        x_axis = path[samples, x_col, turn_index]
        y_axis = path[samples, y_col, turn_index]

    # x/y are passed by keyword: recent seaborn versions reject them
    # positionally.
    fig = sns.jointplot(x=x_axis*scale[x_col], y=y_axis*scale[y_col],
                        kind=plot_kind)
    # Label the joint axes explicitly rather than via pyplot's current axes.
    fig.set_axis_labels(label[x_col], label[y_col])
    return fig

def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
                     stop=None, step=None, profile_plot=True, streak_plot=True):
    """
    Plot data recorded by ProfileMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'ProfileMonitor' object.
    dimension : str, optional
        Dimension to plot, one of {"x","xp","y","yp","tau","delta"}. The
        default is "tau".
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'ProfileMonitor' object, i.e. step % save_every == 0. If None, step is
        equivalent to save_every.
    profile_plot : bool, optional
        If True, bunch profile plot is plotted.
    streak_plot : bool, optional
        If True, streak plot is plotted.

    Returns
    -------
    fig : Figure or tuple of Figure
        (profile figure, streak figure) if both plots are requested, the
        single requested figure otherwise, or None if neither is requested.

    Raises
    ------
    ValueError
        If start or stop is not a recorded turn, stop < start, or step is not
        a positive multiple of the recording step.
    """
    dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
    label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
             "$\\tau$ (ps)", "$\\delta$"]
    dim_index = dimension_dict[dimension]

    with hp.File(filename, "r") as file:
        path = file['ProfileData_{0}'.format(bunch_number)]
        time = path['time'][:]

        if stop is None:
            stop = time[-1]
        elif stop not in time:
            raise ValueError("stop not found. Choose from {0}".format(time))

        if start not in time:
            raise ValueError("start not found. Choose from {0}".format(time))

        if stop < start:
            raise ValueError("stop must not be smaller than start.")

        # Turns between two consecutive records (1 when a single record exists).
        save_every = time[1] - time[0] if len(time) > 1 else 1

        if step is None:
            step = save_every

        if step <= 0:
            raise ValueError("step must be positive.")

        if step % save_every != 0:
            raise ValueError("step must be divisible by the recording step "
                             "which is {0}.".format(save_every))

        num = int((stop - start)/step)
        start_index = int(np.where(time == start)[0][0])

        # Integer record indices of the plotted turns. Floor division keeps
        # them integral (float indices are rejected), and they are strictly
        # increasing, as h5py fancy indexing requires.
        turn_index_array = (start_index
                            + np.arange(num + 1)*int(step // save_every))

        # Rows are the selected turns, columns the bins.
        x_var = np.transpose(
            path["{0}_bin".format(dimension)][:, turn_index_array])
        z_var = np.transpose(path[dimension][:, turn_index_array])
        turn = time[turn_index_array]

    if profile_plot is True:
        fig, ax = plt.subplots()
        for x_row, z_row, t in zip(x_var, z_var, turn):
            ax.plot(x_row*scale[dim_index], z_row,
                    label="turn {0}".format(t))
        ax.set_xlabel(label[dim_index])
        ax.set_ylabel("number of macro-particles")
        ax.legend()

    if streak_plot is True:
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(z_var, cmap=cmap, origin='lower', aspect='auto',
                       extent=[x_var.min()*scale[dim_index],
                               x_var.max()*scale[dim_index],
                               turn.min(), turn.max()])
        ax2.set_xlabel(label[dim_index])
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label("Number of macro-particles")

    if profile_plot is True and streak_plot is True:
        return fig, fig2
    elif profile_plot is True:
        return fig
    elif streak_plot is True:
        return fig2
    
def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
                     stop=None, step=None, profile_plot=False, streak_plot=True,
                     bunch_profile=False, dipole=False):
    """
    Plot data recorded by WakePotentialMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'WakePotentialMonitor' object.
    wake_type : str, optional
        Wake type to plot: "Wlong", "Wxdip", "Wydip", "Wxquad" or "Wyquad".
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'WakePotentialMonitor' object, i.e. step % save_every == 0. If None, 
        step is equivalent to save_every.
    profile_plot : bool, optional
        If True, wake potential profile plot is plotted.
    streak_plot : bool, optional
        If True, streak plot is plotted.
    bunch_profile : bool, optional.
        If True, the bunch profile stored for wake_type ("profile_<wake_type>"
        dataset) is plotted instead of the wake potential.
    dipole : bool, optional
        If True and bunch_profile is False, the dipole moment stored for
        wake_type ("dipole_<wake_type>" dataset) is plotted.

    Returns
    -------
    fig : Figure or tuple of Figure
        (profile figure, streak figure) if both plots are requested, the
        single requested figure otherwise, or None if neither is requested.

    Raises
    ------
    ValueError
        If start or stop is not a recorded turn, stop < start, or step is not
        a positive multiple of the recording step.
    """
    dimension_dict = {"Wlong":0, "Wxdip":1, "Wydip":2, "Wxquad":3, "Wyquad":4}
    # Display-unit factors; presumably the data are stored in V/C (V/C/m for
    # the quadrupolar wakes) -- TODO confirm against the monitor.
    scale = [1e-12, 1e-12, 1e-12, 1e-15, 1e-15]
    label = ["$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$", "$W_{p,x}^Q (V/pC/mm)$",
             "$W_{p,y}^Q (V/pC/mm)$"]

    # The tau axis always belongs to the requested wake type.
    tau_name = "tau_" + wake_type

    if bunch_profile:
        data_name = "profile_" + wake_type
        data_scale = 1
        data_label = "$\\rho$ (a.u.)"
    elif dipole:
        data_name = "dipole_" + wake_type
        data_scale = 1
        data_label = "Dipole moment (m)"
    else:
        data_name = wake_type
        data_scale = scale[dimension_dict[wake_type]]
        data_label = label[dimension_dict[wake_type]]

    with hp.File(filename, "r") as file:
        path = file['WakePotentialData_{0}'.format(bunch_number)]
        time = path['time'][:]

        if stop is None:
            stop = time[-1]
        elif stop not in time:
            raise ValueError("stop not found. Choose from {0}".format(time))

        if start not in time:
            raise ValueError("start not found. Choose from {0}".format(time))

        if stop < start:
            raise ValueError("stop must not be smaller than start.")

        # Turns between two consecutive records (1 when a single record exists).
        save_every = time[1] - time[0] if len(time) > 1 else 1

        if step is None:
            step = save_every

        if step <= 0:
            raise ValueError("step must be positive.")

        if step % save_every != 0:
            raise ValueError("step must be divisible by the recording step "
                             "which is {0}.".format(save_every))

        num = int((stop - start)/step)
        start_index = int(np.where(time == start)[0][0])

        # Integer record indices of the plotted turns. Floor division keeps
        # them integral (float indices are rejected), and they are strictly
        # increasing, as h5py fancy indexing requires.
        turn_index_array = (start_index
                            + np.arange(num + 1)*int(step // save_every))

        # Rows are the selected turns, columns the bins.
        x_var = np.transpose(path[tau_name][:, turn_index_array])
        z_var = np.transpose(path[data_name][:, turn_index_array])*data_scale
        turn = time[turn_index_array]

    if profile_plot is True:
        fig, ax = plt.subplots()
        for tau_row, data_row, t in zip(x_var, z_var, turn):
            ax.plot(tau_row*1e12, data_row, label="turn {0}".format(t))
        ax.set_xlabel("$\\tau$ (ps)")
        ax.set_ylabel(data_label)
        ax.legend()

    if streak_plot is True:
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(z_var, cmap=cmap, origin='lower', aspect='auto',
                       extent=[x_var.min()*1e12,
                               x_var.max()*1e12,
                               turn.min(), turn.max()])
        ax2.set_xlabel("$\\tau$ (ps)")
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label(data_label)

    if profile_plot is True and streak_plot is True:
        return fig, fig2
    elif profile_plot is True:
        return fig
    elif streak_plot is True:
        return fig2
    
def plot_tunedata(filename, bunch_number, ring=None, plot_tune=True, plot_fft=False,
                  dimension='x', min_tune=0, max_tune=0.5, min_turn=None, 
                  max_turn=None, streak_plot=True, profile_plot=False):
    """
    Plot data recorded by TuneMonitor.
    
    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'TuneMonitor' object.
    ring : Synchrotron object, optional
        The ring configuration that is used in TuneMonitor. If None, the default
        value of the revolution period and the revolution frequency are used,
        which are 1.183 us and 0.845 MHz, respectively.
    plot_tune : bool, optional
        If True, tune data is plotted.
    plot_fft : bool, optional
        If True, FFT data is plotted.
    dimension : {'x', 'y', 's'}, optional
        Option to plot FFT data in horizontal, vertical, or longitudinal plane.
    min_tune, max_tune : int or float, optional
        The minimum and the maximum tune values to plot FFT data.
    min_turn, max_turn : int or float, optional
        The minimum and the maximum number of turns to plot FFT data. They are
        converted to record indices by floor division by the recording step,
        which assumes the record starts at turn 0.
    streak_plot : bool, optional
        If True, the FFT data is plotted as a streak plot.
    profile_plot : bool, optional
        If True, the FFT data is plotted as line profiles.
        
    Return
    ------
    fig : Figure or tuple of Figure
        The created figures in the order (tunes, synchrotron tune, FFT streak,
        FFT profiles), keeping only those requested. A single figure is
        returned unwrapped; None is returned if no figure is created.
    """
    group = "TuneData_{0}".format(bunch_number)  # Data group of the HDF5 file
    figs = []  # created figures, in return order

    with hp.File(filename, "r") as file:
        time = file[group]["time"][:]

        if plot_tune is True:
            tune = file[group]["tune"][:]
            tune_spread = file[group]["tune_spread"][:]

            # The first record is skipped; presumably no tune is available
            # yet at that point -- TODO confirm.
            fig1, ax1 = plt.subplots()
            ax1.errorbar(x=time[1:], y=tune[0,1:], yerr=tune_spread[0,1:])
            ax1.errorbar(x=time[1:], y=tune[1,1:], yerr=tune_spread[1,1:])
            ax1.set_xlabel("Turn number")
            ax1.set_ylabel("Transverse tunes")
            ax1.legend(["x","y"])

            fig2, ax2 = plt.subplots()
            ax2.errorbar(x=time[1:], y=tune[2,1:], yerr=tune_spread[2,1:])
            ax2.set_xlabel("Turn number")
            ax2.set_ylabel("Synchrotron tune")
            figs += [fig1, fig2]

        if plot_fft is True:
            if ring is None:
                T0 = 1.183e-06
                f0 = 0.845e6
            else:
                T0 = ring.T0
                f0 = ring.f0

            # The stored FFT has n_freq = n_samples//2 + 1 bins (rfft layout).
            n_freq = file[group]['fft'].shape[1]
            freq = rfftfreq((n_freq-1)*2, T0)
            tune_fft = freq / f0

            dimension_dict = {'x':0, 'y':1, 's':2}
            axis = dimension_dict[dimension]

            fourier_save = file[group]['fft'][axis]

            if max_turn is None:
                max_turn = time[-1]
            if min_turn is None:
                min_turn = time[1]

            min_tune_iloc = np.where(tune_fft >= min_tune)[0][0]
            max_tune_iloc = np.where(tune_fft <= max_tune)[0][-1]
            save_every = int(time[1] - time[0])
            # int() so that float turn bounds still give valid slice indices.
            min_turn_iloc = int(min_turn // save_every)
            max_turn_iloc = int(max_turn // save_every)

            selection = fourier_save[min_tune_iloc:max_tune_iloc+1,
                                     min_turn_iloc:max_turn_iloc+1]

            if streak_plot is True:
                fig3, ax3 = plt.subplots()
                cmap = plt.get_cmap('Blues')

                c = ax3.imshow(np.transpose(np.log(selection)),
                               cmap=cmap, origin='lower', aspect='auto',
                               extent=[min_tune, max_tune, min_turn, max_turn])
                ax3.set_xlabel('$Q_{}$'.format(dimension))
                ax3.set_ylabel("Turns")
                cbar = fig3.colorbar(c, ax=ax3)
                cbar.set_label("log FFT amplitude")
                figs.append(fig3)

            if profile_plot is True:
                fig4, ax4 = plt.subplots()
                ax4.plot(tune_fft[min_tune_iloc:max_tune_iloc+1], selection)
                ax4.set_xlabel('$Q_{}$'.format(dimension))
                ax4.set_ylabel("FFT amplitude")
                ax4.legend(time[min_turn_iloc:max_turn_iloc+1])
                figs.append(fig4)

    if not figs:
        return None
    if len(figs) == 1:
        return figs[0]
    return tuple(figs)

def plot_cavitydata(filename, cavity_name, phasor="cavity", 
                    plot_type="bunch", bunch_number=0, turn=None):
    """
    Plot data recorded by CavityMonitor.

    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    cavity_name : str
        Name of the CavityResonator object.
    phasor : str, optional
        Type of the phasor to plot. Can be "beam" or "cavity".
    plot_type : str, optional
        Type of plot:
            - "bunch" plots the phasor voltage and angle versus time for a 
            given bunch.
            - "turn" plots the phasor voltage and angle versus bunch index for
            a given turn.
            - "streak_volt" plots the phasor voltage versus bunch index and 
            time.
            - "streak_phase" (alias "streak_angle") plots the phasor angle
            versus bunch index and time.
    bunch_number : int, optional
        Bunch number to select. The default is 0.
    turn : int, optional
        Turn to plot, used by plot_type="turn". It must be one of the recorded
        turns. The default is None.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    Raises
    ------
    KeyError
        If phasor is not "cavity" or "beam".
    ValueError
        If plot_type is unknown, or turn is not a recorded turn.
    """
    ph = {"cavity":0, "beam":1}
    labels = ["Cavity", "Beam"]
    label = labels[ph[phasor]]
    # Dataset holding the selected phasor: "cavity_phasor_record" or
    # "beam_phasor_record".
    record_name = phasor + "_phasor_record"

    streak_types = ("streak_volt", "streak_phase", "streak_angle")
    if plot_type not in ("bunch", "turn") + streak_types:
        raise ValueError("plot_type must be one of 'bunch', 'turn', "
                         "'streak_volt', 'streak_phase' or 'streak_angle', "
                         "got {0!r}.".format(plot_type))

    with hp.File(filename, "r") as file:
        cavity_data = file[cavity_name]
        # NumPy copy: element-wise comparison below and plotting after close.
        time = cavity_data["time"][:]
        record = cavity_data[record_name]

        if plot_type == "bunch":
            data = record[bunch_number, :]
            x_axis = time
            xlabel = "Turn number"
        elif plot_type == "turn":
            turn_match = np.where(time == turn)[0]
            if len(turn_match) == 0:
                raise ValueError("Turn {0} is not found. Enter turn from {1}."
                                 .format(turn, time))
            data = record[:, int(turn_match[0])]
            x_axis = np.arange(len(data))
            xlabel = "Bunch index"
        else:
            data = record[:, :]

    if plot_type in ("bunch", "turn"):
        ylabel1 = label + " voltage [MV]"
        ylabel2 = label + " phase [rad]"

        fig, ax = plt.subplots()
        twin = ax.twinx()
        p1, = ax.plot(x_axis, np.abs(data)*1e-6, color="r", label=ylabel1)
        p2, = twin.plot(x_axis, np.angle(data), color="b", label=ylabel2)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel1)
        twin.set_ylabel(ylabel2)

        ax.legend(handles=[p1, p2], loc="best")

        ax.yaxis.label.set_color("r")
        twin.yaxis.label.set_color("b")
    else:
        if plot_type == "streak_volt":
            z_var = np.transpose(np.abs(data)*1e-6)
            cbar_label = label + " voltage [MV]"
        else:
            z_var = np.transpose(np.angle(data))
            cbar_label = label + " phase [rad]"

        fig, ax = plt.subplots()
        cmap = mpl.cm.cool
        # NOTE(review): no extent is given, so the vertical axis shows the
        # record index; it equals the turn number only if every turn is saved.
        c = ax.imshow(z_var, cmap=cmap, origin='lower', aspect='auto')
        ax.set_xlabel("Bunch index")
        ax.set_ylabel("Number of turns")
        cbar = fig.colorbar(c, ax=ax)
        cbar.set_label(cbar_label)

    return fig