# -*- coding: utf-8 -*-
"""
Module for plotting the data recorded by the monitor module during the 
tracking.

@author: Watanyu Foosang, Alexis Gamelin
@Date: 20/07/2021
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import h5py as hp
import random
from scipy.fft import rfftfreq

def plot_beamdata(filenames, dataset="mean", dimension="tau", stat_var="mean", 
                  x_var="time", turn=None, legend=None):
    """
    Plot 2D data recorded by BeamMonitor.

    Parameters
    ----------
    filenames : str or list of str
        Names of the HDF5 files to be plotted.
    dataset : {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted. The default is "mean".
    dimension : str
         The dimension of the dataset to plot:
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
            not used if "current".
        The default is "tau".
    stat_var : {"mean", "std"}
        Statistical value of the dimension.
    x_var : {"time", "index"}
        Choice of the horizontal axis:
            "time" corresponds to turn number.
            "index" corresponds to bunch index.
    turn : int or float, optional
        Choice of the turn to plot when x_var = "index".
        If None, the last turn is plotted.
    legend : list of str, optional
        Legend to add for each file.

    Return
    ------
    fig : Figure
        Figure object with the plot on it.

    """
    
    if isinstance(filenames, str):
        filenames = [filenames]
    
    fig, ax = plt.subplots()
    
    for filename in filenames:
        file = hp.File(filename, "r")
        data = file["Beam"]
        time = np.array(data["time"])
            
        if x_var == "time":
            x = time
            x_label = "Number of turns"
            bunch_index = data["current"][:,0] != 0
            
            if dataset == "current":
                y = np.nansum(data[dataset][bunch_index,:],0)*1e3
                y_label = "Total current (mA)"
            elif dataset == "emit":
                dimension_dict = {"x":0, "y":1, "s":2}
                axis = dimension_dict[dimension]
                label = ["$\\epsilon_{x}$ (m.rad)",
                         "$\\epsilon_{y}$ (m.rad)",
                         "$\\epsilon_{s}$ (m.rad)"]
                if stat_var == "mean":
                    y = np.nanmean(data[dataset][axis,bunch_index,:],0)
                elif stat_var == "std":
                    y = np.nanstd(data[dataset][axis,bunch_index,:],0)
                y_label = stat_var + " " + label[axis]
            elif dataset == "mean" or dataset == "std":
                dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
                                  "delta":5}
                axis = dimension_dict[dimension]
                scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
                label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
                         "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
                if stat_var == "mean":   
                    y = np.nanmean(data[dataset][axis,bunch_index,:],0)*scale[axis]
                elif stat_var == "std":      
                    y = np.nanstd(data[dataset][axis,bunch_index,:],0)*scale[axis]
                label_sup = {"mean":"mean of ", "std":"std of "}
                y_label = label_sup[stat_var] + dataset + " " + label[axis]
                
        elif x_var == "index":
            h = len(data["mean"][0,:,0])
            x = np.arange(h)
            x_label = "Bunch index"
            if turn is None:
                idx = -1
            else:
                idx = np.where(time == int(turn))
            if dataset == "current":
                y = data["current"][:,idx]*1e3
                y_label = "Bunch current (mA)"
            elif dataset == "emit":
                dimension_dict = {"x":0, "y":1, "s":2}
                axis = dimension_dict[dimension]
                label = ["$\\epsilon_{x}$ (m.rad)",
                         "$\\epsilon_{y}$ (m.rad)",
                         "$\\epsilon_{s}$ (m.rad)"]
                y = data["emit"][axis,:,idx]
                y_label = label[axis]
            elif dataset == "mean" or dataset == "std":
                dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
                                  "delta":5}
                axis = dimension_dict[dimension]
                scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
                label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
                         "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
                y = data[dataset][axis,:,idx]*scale[axis]
                y_label = dataset + " " + label[axis]
        else:
            raise ValueError("x_var should be time or index")
        
        ax.plot(x, y)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        if legend is not None:
            plt.legend(legend)
            
        file.close()
        
    return fig
            
def streak_beamdata(filename, dataset="mean", dimension="tau", cm_lim=None):
    """
    Plot 3D data recorded by BeamMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    dataset : {"current","emit","mean","std"}
        HDF5 file's dataset to be plotted. The default is "mean".
    dimension : str
         The dimension of the dataset to plot:
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"}.
            not used if "current".
        The default is "tau".
    cm_lim : list [vmin, vmax], optional
        Colormap scale for the "streak" plot.

    Return
    ------
    fig : Figure
        Figure object with the plot on it.

    """
    
    file = hp.File(filename, "r")
    data = file["Beam"]
    time = np.array(data["time"])
        
    h = len(data["mean"][0,:,0])
    x = np.arange(h)
    x_label = "Bunch index"
    y = time
    y_label = "Number of turns"
    if dataset == "current":
        z = (np.array(data["current"])*1e3).T
        z_label = "Bunch current (mA)"
        title = z_label
    elif dataset == "emit":
        dimension_dict = {"x":0, "y":1, "s":2}
        axis = dimension_dict[dimension]
        label = ["$\\epsilon_{x}$ (m.rad)",
                 "$\\epsilon_{y}$ (m.rad)",
                 "$\\epsilon_{s}$ (m.rad)"]
        z = np.array(data["emit"][axis,:,:]).T
        z_label = label[axis]
        title = z_label
    else:
        dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, 
                              "delta":5}
        axis = dimension_dict[dimension]
        scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
        label = ["x (um)", "x' ($\\mu$rad)", "y (um)", 
                     "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
        z = np.array(data[dataset][axis,:,:]).T*scale[axis]
        z_label = label[axis]
        if dataset == "mean":
            title = label[axis] + " CM"
        elif dataset == "std":
            title = label[axis] + " RMS"
            
    fig, ax = plt.subplots()
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    
    if dataset == "mean":
        cmap = mpl.cm.coolwarm # diverging
    else:
        cmap = mpl.cm.inferno # sequential
    
    c = ax.imshow(z, cmap=cmap, origin='lower' , aspect='auto',
            extent=[x.min(), x.max(), y.min(), y.max()])
    if cm_lim is not None:
        c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
    cbar = fig.colorbar(c, ax=ax)
    cbar.set_label(z_label)
    
    file.close()
        
    return fig
              
def plot_bunchdata(filenames, bunch_number, dataset, dimension="x", 
                   legend=None):
    """
    Plot data recorded by BunchMonitor.
    
    Parameters
    ----------
    filenames : str or list of str
        Names of the HDF5 files to be plotted.
    bunch_number : int or list of int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'BunchMonitor' object.
    dataset : {"current", "emit", "mean", "std", "cs_invariant"}
        HDF5 file's dataset to be plotted.
    dimension : str, optional
        The dimension of the dataset to plot. Use "None" for "current",
        otherwise use the following : 
            for "emit", dimension = {"x","y","s"},
            for "mean" and "std", dimension = {"x","xp","y","yp","tau","delta"},
            for "action", dimension = {"x","y"}.
    legend : list of str, optional
        Legend to add for each file.
        
    Return
    ------
    fig : Figure
        Figure object with the plot on it.

    """
    
    if isinstance(filenames, str):
        filenames = [filenames]
        
    if isinstance(bunch_number, int):
        ll = []
        for i in range(len(filenames)):
            ll.append(bunch_number)
        bunch_number = ll
        
    fig, ax = plt.subplots()
    
    for i, filename in enumerate(filenames):
        file = hp.File(filename, "r")
        group = "BunchData_{0}".format(bunch_number[i])  # Data group of the HDF5 file
        
        if dataset == "current":
            y_var = file[group][dataset][:]*1e3
            label = "current (mA)"
            
        elif dataset == "emit":
            dimension_dict = {"x":0, "y":1, "s":2}
                             
            y_var = file[group][dataset][dimension_dict[dimension]]*1e9
            
            if dimension == "x": label = "hor. emittance (nm.rad)"
            elif dimension == "y": label = "ver. emittance (nm.rad)"
            elif dimension == "s": label = "long. emittance (nm.rad)"
            
            
        elif dataset == "mean" or dataset == "std":                        
            dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5} 
            scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]        
            axis_index = dimension_dict[dimension]
            
            y_var = file[group][dataset][axis_index]*scale[axis_index]
            if dataset == "mean":
                label_list = ["x ($\\mu$m)", "x' ($\\mu$rad)", "y ($\\mu$m)",
                              "y' ($\\mu$rad)", "$\\tau$ (ps)", "$\\delta$"]
            else:
                label_list = ["$\\sigma_x$ ($\\mu$m)", "$\\sigma_{x'}$ ($\\mu$rad)",
                              "$\\sigma_y$ ($\\mu$m)", "$\\sigma_{y'}$ ($\\mu$rad)", 
                              "$\\sigma_{\\tau}$ (ps)", "$\\sigma_{\\delta}$"]
            
            label = label_list[axis_index]
            
        elif dataset == "cs_invariant":
            dimension_dict = {"x":0, "y":1}
            axis_index = dimension_dict[dimension]
            y_var = file[group][dataset][axis_index]
            label_list = ['$J_x$ (m)', '$J_y$ (m)']
            label = label_list[axis_index]

        x_axis = file[group]["time"][:]
        xlabel = "Number of turns"
            
        fig, ax = plt.subplots()        
        ax.plot(x_axis, y_var)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(label)
        if legend is not None:
            plt.legend(legend)
            
        file.close()
        
    return fig
            
def plot_phasespacedata(filename, bunch_number, x_var, y_var, turn,
                        only_alive=True, plot_size=1, plot_kind='kde'):
    """
    Plot data recorded by PhaseSpaceMonitor.

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'PhaseSpaceMonitor' object.
    x_var, y_var : str {"x", "xp", "y", "yp", "tau", "delta"}
        If dataset is "particles", the variables to be plotted on the 
        horizontal and the vertical axes need to be specified.
    turn : int
        Turn at which the data will be extracted.
    only_alive : bool, optional
        When only_alive is True, only alive particles are plotted and dead 
        particles will be discarded.
    plot_size : [0,1], optional
        Number of macro-particles to plot relative to the total number 
        of macro-particles recorded. This option helps reduce processing time
        when the data is big.
    plot_kind : {'scatter', 'kde', 'hex', 'reg', 'resid'}, optional
        The plot style. The default is 'kde'. 
        
    Return
    ------
    fig : Figure
        Figure object with the plot on it.
    """
    
    file = hp.File(filename, "r")
    
    group = "PhaseSpaceData_{0}".format(bunch_number)
    dataset = "particles"

    option_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e3,1e3,1e3,1e3,1e12,1]
    label = ["x (mm)","x' (mrad)","y (mm)","y' (mrad)","$\\tau$ (ps)",
             "$\\delta$"]
    
    # find the index of "turn" in time array
    turn_index = np.where(file["PhaseSpaceData_0"]["time"][:]==turn) 
    
    if len(turn_index[0]) == 0:
        raise ValueError("Turn {0} is not found. Enter turn from {1}.".
                         format(turn, file[group]["time"][:]))     
    
    path = file[group][dataset]
    mp_number = path[:,0,0].size
    
    if only_alive is True:
        index = np.where(file[group]["alive"][:,turn_index])[0]
    else:
        index = np.arange(mp_number)
        
    if plot_size == 1:
        samples = index
    elif plot_size < 1:
        samples_meta = random.sample(list(index), int(plot_size*mp_number))
        samples = sorted(samples_meta)
    else:
        raise ValueError("plot_size must be in range [0,1].")
            
    # format : sns.jointplot(x_axis, yaxis, kind)
    x_axis = path[samples,option_dict[x_var],turn_index[0][0]]
    y_axis = path[samples,option_dict[y_var],turn_index[0][0]]    
        
    fig = sns.jointplot(x_axis*scale[option_dict[x_var]], 
                        y_axis*scale[option_dict[y_var]], kind=plot_kind)
   
    plt.xlabel(label[option_dict[x_var]])
    plt.ylabel(label[option_dict[y_var]])
            
    file.close()
    return fig

def plot_profiledata(filename, bunch_number, dimension="tau", start=0,
                     stop=None, step=None, profile_plot=True, streak_plot=True):
    """
    Plot data recorded by ProfileMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'ProfileMonitor' object.
    dimension : str, optional
        Dimension to plot. The default is "tau"
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'ProfileMonitor' object, i.e. step % save_every == 0. If None, step is
        equivalent to save_every.
    profile_plot : bool, optional
        If Ture, bunch profile plot is plotted.
    streak_plot : bool, optional
        If True, strek plot is plotted.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    """
    
    file = hp.File(filename, "r")
    path = file['ProfileData_{0}'.format(bunch_number)]
    l_bound = path["{0}_bin".format(dimension)]
    
    if stop is None:
        stop = path['time'][-1]
    elif stop not in path['time']:
        raise ValueError("stop not found. Choose from {0}"
                         .format(path['time'][:]))
 
    if start not in path['time']:
        raise ValueError("start not found. Choose from {0}"
                         .format(path['time'][:]))
    
    save_every = path['time'][1] - path['time'][0]
    
    if step is None:
        step = save_every
    
    if step % save_every != 0:
        raise ValueError("step must be divisible by the recording step "
                         "which is {0}.".format(save_every))
    
    dimension_dict = {"x":0, "xp":1, "y":2, "yp":3, "tau":4, "delta":5}
    scale = [1e6, 1e6, 1e6, 1e6, 1e12, 1]
    label = ["x (um)", "x' ($\\mu$rad)", "y (um)", "y' ($\\mu$rad)",
             "$\\tau$ (ps)", "$\\delta$"]
    
    num = int((stop - start)/step)
    n_bin = len(path[dimension][:,0])
    
    start_index = np.where(path['time'][:] == start)[0][0]
    
    x_var = np.zeros((num+1,n_bin))
    turn_index_array = np.zeros((num+1,))
    for i in range(num+1):
        turn_index = start_index + i * step / save_every 
        turn_index_array[i] = turn_index
        # construct an array of bin mids
        x_var[i,:] = l_bound[:,turn_index]
        
    if profile_plot is True:
        fig, ax = plt.subplots()
        for i in range(num+1):
            ax.plot(x_var[i]*scale[dimension_dict[dimension]],
                    path[dimension][:,turn_index_array[i]], 
                    label="turn {0}".format(path['time'][turn_index_array[i]]))
        ax.set_xlabel(label[dimension_dict[dimension]])
        ax.set_ylabel("number of macro-particles")         
        ax.legend()
            
    if streak_plot is True:
        turn = np.reshape(path['time'][turn_index_array], (num+1,1))
        y_var = np.ones((num+1,n_bin)) * turn
        z_var = np.transpose(path[dimension][:,turn_index_array])
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
                       extent=[x_var.min()*scale[dimension_dict[dimension]],
                               x_var.max()*scale[dimension_dict[dimension]],
                               y_var.min(),y_var.max()])
        ax2.set_xlabel(label[dimension_dict[dimension]])
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label("Number of macro-particles") 

    file.close()
    if profile_plot is True and streak_plot is True:
        return fig, fig2
    elif profile_plot is True:
        return fig
    elif streak_plot is True:
        return fig2
    
def plot_wakedata(filename, bunch_number, wake_type="Wlong", start=0,
                     stop=None, step=None, profile_plot=False, streak_plot=True,
                     bunch_profile=False, dipole=False):
    """
    Plot data recorded by WakePotentialMonitor

    Parameters
    ----------
    filename : str
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'WakePotentialMonitor' object.
    wake_type : str, optional
        Wake type to plot: "Wlong", "Wxdip", ... 
    start : int, optional
        First turn to plot. The default is 0.
    stop : int, optional
        Last turn to plot. If None, the last turn of the record is selected.
    step : int, optional
        Plotting step. This has to be divisible by 'save_every' parameter in
        'WakePotentialMonitor' object, i.e. step % save_every == 0. If None, 
        step is equivalent to save_every.
    profile_plot : bool, optional
        If Ture, wake potential profile plot is plotted.
    streak_plot : bool, optional
        If True, strek plot is plotted.
    bunch_profile : bool, optional.
        If True, the bunch profile is plotted.
    dipole : bool, optional
        If True, the dipole moment is plotted.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    """
    
    file = hp.File(filename, "r")
    path = file['WakePotentialData_{0}'.format(bunch_number)]
    
    if stop is None:
        stop = path['time'][-1]
    elif stop not in path['time']:
        raise ValueError("stop not found. Choose from {0}"
                         .format(path['time'][:]))
 
    if start not in path['time']:
        raise ValueError("start not found. Choose from {0}"
                         .format(path['time'][:]))
    
    save_every = path['time'][1] - path['time'][0]
    
    if step is None:
        step = save_every
    
    if step % save_every != 0:
        raise ValueError("step must be divisible by the recording step "
                         "which is {0}.".format(save_every))
    
    dimension_dict = {"Wlong":0, "Wxdip":1, "Wydip":2, "Wxquad":3, "Wyquad":4}
    scale = [1e-12, 1e-12, 1e-12, 1e-15, 1e-15]
    label = ["$W_p$ (V/pC)", "$W_{p,x}^D (V/pC)$", "$W_{p,y}^D (V/pC)$", "$W_{p,x}^Q (V/pC/mm)$",
             "$W_{p,y}^Q (V/pC/mm)$"]
    
    if bunch_profile == True:
        tau_name = "tau_" + wake_type
        wake_type = "profile_" + wake_type
        dimension_dict = {wake_type:0}
        scale = [1]
        label = ["$\\rho$ (a.u.)"]
        
    elif dipole == True:
        tau_name = "tau_" + wake_type
        wake_type = "dipole_" + wake_type
        dimension_dict = {wake_type:0}
        scale = [1]
        label = ["Dipole moment (m)"]
        
    else:
        tau_name = "tau_" + wake_type
        
    num = int((stop - start)/step)
    n_bin = len(path[wake_type][:,0])
    
    start_index = np.where(path['time'][:] == start)[0][0]
    
    x_var = np.zeros((num+1,n_bin))
    turn_index_array = np.zeros((num+1,))
    for i in range(num+1):
        turn_index = start_index + i * step / save_every 
        turn_index_array[i] = turn_index
        # construct an array of bin mids
        x_var[i,:] = path[tau_name][:,turn_index]
        
    if profile_plot is True:
        fig, ax = plt.subplots()
        for i in range(num+1):
            ax.plot(x_var[i]*1e12,
                    path[wake_type][:,turn_index_array[i]]*scale[dimension_dict[wake_type]], 
                    label="turn {0}".format(path['time'][turn_index_array[i]]))
        ax.set_xlabel("$\\tau$ (ps)")
        ax.set_ylabel(label[dimension_dict[wake_type]])         
        ax.legend()
            
    if streak_plot is True:
        turn = np.reshape(path['time'][turn_index_array], (num+1,1))
        y_var = np.ones((num+1,n_bin)) * turn
        z_var = np.transpose(path[wake_type][:,turn_index_array]*scale[dimension_dict[wake_type]])
        fig2, ax2 = plt.subplots()
        cmap = mpl.cm.cool
        c = ax2.imshow(z_var, cmap=cmap, origin='lower' , aspect='auto',
                       extent=[x_var.min()*1e12,
                               x_var.max()*1e12,
                               y_var.min(),y_var.max()])
        ax2.set_xlabel("$\\tau$ (ps)")
        ax2.set_ylabel("Number of turns")
        cbar = fig2.colorbar(c, ax=ax2)
        cbar.set_label(label[dimension_dict[wake_type]]) 

    file.close()
    if profile_plot is True and streak_plot is True:
        return fig, fig2
    elif profile_plot is True:
        return fig
    elif streak_plot is True:
        return fig2
    
def plot_tunedata(filename, bunch_number, f0, plot_tune=True, plot_fft=False,
                  dimension='x', min_tune=0, max_tune=0.5, min_turn=None, 
                  max_turn=None, streak_plot=True, profile_plot=False):
    """
    Plot data recorded by TuneMonitor.
    
    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    bunch_number : int
        Bunch to plot. This has to be identical to 'bunch_number' parameter in 
        'BunchMonitor' object.
    f0 : float
        Revolution frequency of the ring used for the tracking in [Hz].
    plot_tune : bool, optional
        If True, tune data usinf NAFF is plotted.
    plot_fft : bool, optional
        If True, FFT data is plotted.
    dimension : {'x', 'y', 's'}
        Option to plot FFT data in horizontal, vertical, or longitudinal plane.
    min_tune, max_tune : int or float, optional
        The minimum and the maximum tune values to plot FFT data.
    min_turn, max_turn : int or float, optional
        The minimum and the maximum number of turns to plot FFT data.
    streak_plot : bool, optional
        If True, the FFT data is plotted as a streak plot.
    profile_plot : bool, optional.
        If True, the FFT data is plotted as line plots.
        
    Return
    ------
    fig : Figure
        Figure objects with the plot on it.

    """
    
    file = hp.File(filename, "r")
    
    group = "TuneData_{0}".format(bunch_number)  # Data group of the HDF5 file
    time = file[group]["time"]
    
    fig_to_return = []
    
    if plot_tune is True:
        tune = file[group]["tune"]
        tune_spread = file[group]["tune_spread"]
            
        fig1, ax1 = plt.subplots()        
        ax1.errorbar(x=time, y=tune[0,:], yerr=tune_spread[0,:])
        ax1.errorbar(x=time, y=tune[1,:], yerr=tune_spread[1,:])
        ax1.set_xlabel("Turn number")
        ax1.set_ylabel("Transverse tunes")
        plt.legend(["x","y"])
        fig_to_return.append(fig1)
        
        fig2, ax2 = plt.subplots()        
        ax2.errorbar(x=time, y=tune[2,:], yerr=tune_spread[2,:])
        ax2.set_xlabel("Turn number")
        ax2.set_ylabel("Synchrotron tune")
        fig_to_return.append(fig2)
        
    if plot_fft is True:
        
        T0 = 1/f0
        n_freq = file[group]['fft'].shape[1]
        freq = rfftfreq((n_freq-1)*2, T0)
        tune_fft = freq / f0
        
        dimension_dict = {'x':0, 'y':1, 's':2}
        axis = dimension_dict[dimension]
        
        fourier_save = file[group]['fft'][axis]
        
        if max_turn is None:
            max_turn = time[-1]
        if min_turn is None:
            min_turn = time[0]
            
        min_tune_iloc = np.where(tune_fft >= min_tune)[0][0]
        max_tune_iloc = np.where(tune_fft <= max_tune)[0][-1]
        save_every = int(time[1] - time[0])
        min_turn_iloc = min_turn // save_every - 1
        max_turn_iloc = max_turn // save_every
        
    
        if streak_plot is True:
            fig3, ax3 = plt.subplots()
            cmap = plt.get_cmap('Blues')
        
            c = ax3.imshow(np.transpose(np.log(
                          fourier_save[min_tune_iloc:max_tune_iloc+1, 
                                       min_turn_iloc:max_turn_iloc+1])),
                          cmap=cmap, origin='lower' , aspect='auto',
                          extent=[min_tune, max_tune, min_turn, max_turn])
            ax3.set_xlabel('$Q_{}$'.format(dimension))
            ax3.set_ylabel("Turns")
            cbar = fig3.colorbar(c, ax=ax3)
            cbar.set_label("log FFT amplitude") 
            fig_to_return.append(fig3)
            
        if profile_plot is True:
            fig4, ax4 = plt.subplots()
            ax4.plot(tune_fft[min_tune_iloc:max_tune_iloc+1], 
                     fourier_save[min_tune_iloc:max_tune_iloc+1,
                                  min_turn_iloc:max_turn_iloc+1])
            ax4.set_xlabel('$Q_{}$'.format(dimension))
            ax4.set_ylabel("FFT amplitude")
            ax4.legend(time[min_turn_iloc:max_turn_iloc+1])
            fig_to_return.append(fig4)
            
    file.close()
   
    return tuple(fig_to_return)

def plot_cavitydata(filename, cavity_name, phasor="cavity", 
                    plot_type="bunch", bunch_number=0, turn=None, cm_lim=None):
    """
    Plot data recorded by CavityMonitor.

    Parameters
    ----------
    filename : str 
        Name of the HDF5 file that contains the data.
    cavity_name : str
        Name of the CavityResonator object.
    phasor : str, optional
        Type of the phasor to plot. Can be "beam" or "cavity".
    plot_type : str, optional
        Type of plot:
            - "bunch" plots the phasor voltage and angle versus time for a 
            given bunch.
            - "turn" plots the phasor voltage and ange versus bunch index for
            a given turn.
            - "streak_volt" plots the phasor voltage versus bunch index and 
            time
            - "streak_angle" plots the phasor angle versus bunch index and 
            time
    bunch_number : int, optional
        Bunch number to select. The default is 0.
    turn : int, optional
        Turn to plot. The default is None.
    cm_lim : list [vmin, vmax], optional
        Colormap scale for the "streak" plots.

    Returns
    -------
    fig : Figure
        Figure object with the plot on it.

    """
    
    file = hp.File(filename, "r")
    cavity_data = file[cavity_name]
    
    time = cavity_data["time"]
    
    ph = {"cavity":0, "beam":1}
    labels = ["Cavity", "Beam"]
    
    if plot_type == "bunch":
    
        data = [cavity_data["cavity_phasor_record"][bunch_number,:], 
                cavity_data["beam_phasor_record"][bunch_number,:]]

        ylabel1 = labels[ph[phasor]] + " voltage [MV]"
        ylabel2 = labels[ph[phasor]] + " phase [rad]"
        
        fig, ax = plt.subplots()
        twin = ax.twinx()
        p1, = ax.plot(time, np.abs(data[ph[phasor]])*1e-6, color="r",label=ylabel1)
        p2, = twin.plot(time, np.angle(data[ph[phasor]]), color="b", label=ylabel2)
        ax.set_xlabel("Turn number")
        ax.set_ylabel(ylabel1)
        twin.set_ylabel(ylabel2)
        
        plots = [p1, p2]
        ax.legend(handles=plots, loc="best")
        
        ax.yaxis.label.set_color("r")
        twin.yaxis.label.set_color("b")
        
    if plot_type == "turn":
        
        index = np.array(time) == turn
        ph = {"cavity":0, "beam":1}
        data = [cavity_data["cavity_phasor_record"][:,index], 
                cavity_data["beam_phasor_record"][:,index]]
        labels = ["Cavity", "Beam"]
        
        h=len(data[0])
        x=np.arange(h)

        ylabel1 = labels[ph[phasor]] + " voltage [MV]"
        ylabel2 = labels[ph[phasor]] + " phase [rad]"
        
        fig, ax = plt.subplots()
        twin = ax.twinx()
        p1, = ax.plot(x, np.abs(data[ph[phasor]])*1e-6, color="r",label=ylabel1)
        p2, = twin.plot(x, np.angle(data[ph[phasor]]), color="b", label=ylabel2)
        ax.set_xlabel("Bunch index")
        ax.set_ylabel(ylabel1)
        twin.set_ylabel(ylabel2)
        
        plots = [p1, p2]
        ax.legend(handles=plots, loc="best")
        
        ax.yaxis.label.set_color("r")
        twin.yaxis.label.set_color("b")
        
    if plot_type == "streak_volt" or plot_type == "streak_phase":
        
        if plot_type == "streak_volt":
            data = np.transpose(np.abs(cavity_data["cavity_phasor_record"][:,:])*1e-6)
            ylabel = labels[ph[phasor]] + " voltage [MV]"
        elif plot_type == "streak_phase":
            data = np.transpose(np.angle(cavity_data["cavity_phasor_record"][:,:]))
            ylabel = labels[ph[phasor]] + " phase [rad]"
            
        fig, ax = plt.subplots()
        cmap = mpl.cm.cool
        c = ax.imshow(data, cmap=cmap, origin='lower' , aspect='auto')
        if cm_lim is not None:
            c.set_clim(vmin=cm_lim[0],vmax=cm_lim[1])
        ax.set_xlabel("Bunch index")
        ax.set_ylabel("Number of turns")
        cbar = fig.colorbar(c, ax=ax)
        cbar.set_label(ylabel) 
    
    file.close()
    return fig