# -*- coding: utf-8 -*-
"""
This module defines the different monitor classes which are used to save data
during tracking.

@author: Alexis Gamelin
@date: 24/03/2020

"""

import numpy as np
import h5py as hp
import PyNAFF as pnf
import random
from mbtrack2.tracking.element import Element
from mbtrack2.tracking.particles import Bunch, Beam
from mbtrack2.tracking.rf import CavityResonator
from scipy.interpolate import interp1d
from abc import ABCMeta
from scipy.fft import rfft, rfftfreq
from mpi4py import MPI

class Monitor(Element, metaclass=ABCMeta):
    """
    Abstract Monitor class used for subclass inheritance to define all the
    different kinds of monitor objects.
    
    The Monitor class is based on the h5py module to be able to write data on
    structured binary files. The class provides a common file where the
    different Monitor subclasses can write.
    
    Attributes
    ----------
    file : HDF5 file
        Common file where all monitors, Monitor subclass elements, write the
        saved data. Based on class attribute _file_storage.
    file_name : string
        Name of the HDF5 file where the data is stored. Based on class 
        attribute _file_name_storage.
        
    Methods
    -------
    monitor_init(group_name, save_every, buffer_size, total_size,
                     dict_buffer, dict_file, file_name=None, mpi_mode=True)
        Method called to initialize Monitor subclass.
    write()
        Write data from buffer to the HDF5 file.
    to_buffer(object_to_save)
        Save data to buffer.
    close()
        Close the HDF5 file shared by all Monitor subclass, must be called 
        by at least an instance of a Monitor subclass at the end of the 
        tracking.
    track_bunch_data(object_to_save)
        Track method to use when saving bunch data.
    """
    
    # Class-level lists shared by every instance of every subclass: they hold
    # at most one file name and one open HDF5 file, so that all monitors
    # write into the same file.
    _file_name_storage = []
    _file_storage = []
    
    @property
    def file_name(self):
        """Name of the HDF5 file where the data is stored."""
        try:
            return self._file_name_storage[0]
        except IndexError:
            raise ValueError(
                "The HDF5 file name for monitors is not set.") from None
            
    @property
    def file(self):
        """Common file where all monitors, Monitor subclass elements, write the
        saved data."""
        try:
            return self._file_storage[0]
        except IndexError:
            raise ValueError(
                "The HDF5 file to store data is not set.") from None
    
    @staticmethod
    def _to_int_shape(shape):
        """Return *shape* (an int or a sequence) with every dimension as int."""
        if np.ndim(shape) == 0:
            return int(shape)
        return tuple(int(dim) for dim in shape)
            
    def monitor_init(self, group_name, save_every, buffer_size, total_size,
                     dict_buffer, dict_file, file_name=None, mpi_mode=True):
        """
        Method called to initialize Monitor subclass. 
        
        Parameters
        ----------
        group_name : string
            Name of the HDF5 group in which the data for the current monitor 
            will be saved.
        save_every : int or float
            Set the frequency of the save. The data is saved every save_every 
            call of the monitor.
        buffer_size : int or float
            Size of the save buffer.
        total_size : int or float
            Total size of the save. The following relationships between the 
            parameters must exist: 
                total_size % buffer_size == 0
                number of call to track / save_every == total_size
        dict_buffer : dict
            Dictionary with keys as the attribute name to save and values as
            the shape of the buffer to create to hold the attribute, like 
            (key.shape, buffer_size)
        dict_file : dict
            Dictionary with keys as the attribute name to save and values as
            the shape of the dataset to create to hold the attribute, like 
            (key.shape, total_size)
        file_name : string, optional
            Name of the HDF5 where the data will be stored. Must be specified
            the first time a subclass of Monitor is instantiated and must be
            None the following times.
        mpi_mode : bool, optional
            If True, open the HDF5 file in parallel mode, which is needed to
            allow several cores to write in the same file at the same time.
            If False, open the HDF5 file in standard mode.
            
        Raises
        ------
        ValueError
            If save_every or buffer_size is not positive, if total_size is not
            divisible by buffer_size, or if file_name is given while a file
            name or file is already registered.
        """
        save_every = int(save_every)
        total_size = int(total_size)
        buffer_size = int(buffer_size)
        
        # Validate before touching the shared file state, so that a bad
        # configuration does not leave a half-registered HDF5 file behind.
        if save_every <= 0:
            raise ValueError("save_every must be a positive integer.")
        if buffer_size <= 0:
            raise ValueError("buffer_size must be a positive integer.")
        if total_size % buffer_size != 0:
            raise ValueError("total_size must be divisible by buffer_size.")
        
        # setup and open common file for all monitors
        if file_name is not None:
            if self._file_name_storage:
                raise ValueError("File name for monitors is already attributed.")
            if self._file_storage:
                raise ValueError("File is already open.")
            full_name = file_name + ".hdf5"
            if mpi_mode:
                f = hp.File(full_name, "w", libver='latest', 
                            driver='mpio', comm=MPI.COMM_WORLD)
            else:
                f = hp.File(full_name, "w", libver='latest')
            # Register only once the file is actually open.
            self._file_name_storage.append(full_name)
            self._file_storage.append(f)
        
        self.group_name = group_name
        self.save_every = save_every
        self.total_size = total_size
        self.buffer_size = buffer_size
        self.buffer_count = 0
        self.write_count = 0
        self.track_count = 0
        
        # setup attribute buffers from values given in dict_buffer; shapes are
        # coerced to int because sizes may be given as floats (e.g. 2e4).
        for key, shape in dict_buffer.items():
            setattr(self, key, np.zeros(self._to_int_shape(shape)))
        self.time = np.zeros((self.buffer_size,), dtype=int)

        # create HDF5 groups and datasets to save data from group_name and 
        # dict_file
        self.g = self.file.require_group(self.group_name)
        self.g.require_dataset("time", (self.total_size,), dtype=int)
        for key, shape in dict_file.items():
            self.g.require_dataset(key, self._to_int_shape(shape), dtype=float)
        
        # For each dataset: full slices over every axis but the last one
        # (the time axis), which is appended at buffer/write time.
        self.slice_dict = {key: [slice(None)] * (len(shape) - 1)
                           for key, shape in dict_file.items()}
        
    def write(self):
        """Write data from buffer to the HDF5 file."""
        start = self.write_count * self.buffer_size
        time_slice = slice(start, start + self.buffer_size)
        group = self.file[self.group_name]
        
        group["time"][time_slice] = self.time
        for key in self.dict_buffer:
            slice_tuple = tuple(self.slice_dict[key]) + (time_slice,)
            group[key][slice_tuple] = getattr(self, key)
        
        self.file.flush()
        self.write_count += 1
        
    def to_buffer(self, object_to_save):
        """
        Save data to buffer.
        
        Each attribute listed in dict_buffer is read from object_to_save and
        stored in the current buffer column; the buffer is written to the
        file once it is full.
        
        Parameters
        ----------
        object_to_save : python object
            Depends on the Monitor subclass, typically a Beam or Bunch object.
        """
        self.time[self.buffer_count] = self.track_count
        for key in self.dict_buffer:
            slice_tuple = tuple(self.slice_dict[key]) + (self.buffer_count,)
            getattr(self, key)[slice_tuple] = getattr(object_to_save, key)
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write()
            self.buffer_count = 0
            
    def close(self):
        """
        Close the HDF5 file shared by all Monitor subclass, must be called 
        by at least an instance of a Monitor subclass at the end of the 
        tracking.
        
        Does nothing if no file has been set. The registered file name is
        kept, so file_name stays readable after closing.
        """
        try:
            self.file.close()
        except ValueError:
            pass
        
    def track_bunch_data(self, object_to_save):
        """
        Track method to use when saving bunch data.
        
        Data is buffered every save_every calls; the call counter is
        incremented on every call.
        
        Parameters
        ----------
        object_to_save : Beam or Bunch
        
        Raises
        ------
        TypeError
            If object_to_save is neither a Beam nor a Bunch.
        """
        if self.track_count % self.save_every == 0:
            if isinstance(object_to_save, Beam):
                if object_to_save.mpi_switch:
                    # In MPI mode, mpi.bunch_num is the bunch handled by this
                    # process: only save if it is the monitored one.
                    bunch_num = object_to_save.mpi.bunch_num
                    if bunch_num == self.bunch_number:
                        self.to_buffer(object_to_save[bunch_num])
                else:
                    self.to_buffer(object_to_save[self.bunch_number])
            elif isinstance(object_to_save, Bunch):
                self.to_buffer(object_to_save)
            else:
                raise TypeError("object_to_save should be a Beam or Bunch object.")
        self.track_count += 1
            
class BunchMonitor(Monitor):
    """
    Monitor a single bunch and save attributes 
    (mean, std, emit, current, and cs_invariant).
    
    Parameters
    ----------
    bunch_number : int
        Bunch to monitor
    file_name : string, optional
        Name of the HDF5 where the data will be stored. Must be specified
        the first time a subclass of Monitor is instantiated and must be None
        the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every 
        call of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(object_to_save)
        Save data
    """
    
    def __init__(self, bunch_number, file_name=None, save_every=5,
                 buffer_size=500, total_size=2e4, mpi_mode=True):
        
        self.bunch_number = bunch_number
        group_name = "BunchData_" + str(self.bunch_number)
        # Sizes are documented as "int or float" (default total_size is 2e4),
        # but array and dataset shapes need integers.
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        dict_buffer = {"mean":(6, buffer_size), "std":(6, buffer_size),
                       "emit":(3, buffer_size), "current":(buffer_size,),
                       "cs_invariant":(2, buffer_size)}
        dict_file = {"mean":(6, total_size), "std":(6, total_size),
                     "emit":(3, total_size), "current":(total_size,),
                     "cs_invariant":(2, total_size)}
        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
                    
    def track(self, object_to_save):
        """
        Save data
        
        Parameters
        ----------
        object_to_save : Bunch or Beam object
        """        
        self.track_bunch_data(object_to_save)
        
class PhaseSpaceMonitor(Monitor):
    """
    Monitor a single bunch and save the full phase space.
    
    Parameters
    ----------
    bunch_number : int
        Bunch to monitor
    mp_number : int or float
        Number of macroparticles in the phase space to save. If less than the 
        total number of macroparticles, a random fraction of the bunch is 
        saved; a new random subset is drawn at each save.
    file_name : string, optional
        Name of the HDF5 where the data will be stored. Must be specified
        the first time a subclass of Monitor is instantiated and must be None
        the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every 
        call of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(object_to_save)
        Save data
    """
    
    def __init__(self, bunch_number, mp_number, file_name=None, save_every=1e3,
                 buffer_size=10, total_size=100, mpi_mode=True):
        
        self.bunch_number = bunch_number
        self.mp_number = int(mp_number)
        group_name = "PhaseSpaceData_" + str(self.bunch_number)
        # Sizes are documented as "int or float"; shapes need integers.
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        dict_buffer = {"particles":(self.mp_number, 6, buffer_size), 
                       "alive":(self.mp_number, buffer_size)}
        dict_file = {"particles":(self.mp_number, 6, total_size),
                     "alive":(self.mp_number, total_size)}
        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
                    
    def track(self, object_to_save):
        """
        Save data
        
        Parameters
        ----------
        object_to_save : Bunch or Beam object
        """        
        self.track_bunch_data(object_to_save)
        
    def to_buffer(self, bunch):
        """
        Save data to buffer.
        
        If the bunch does not hold exactly mp_number macroparticles, a sorted
        random sample of mp_number of them is saved.
        
        Parameters
        ----------
        bunch : Bunch object
        
        Raises
        ------
        ValueError
            From random.sample, if mp_number exceeds the number of
            macroparticles in the bunch.
        """
        self.time[self.buffer_count] = self.track_count
        
        n_particles = len(bunch.alive)
        if n_particles != self.mp_number:
            # Sampling from a range avoids building a full index list;
            # random.sample selects the same indices either way.
            samples = sorted(random.sample(range(n_particles), self.mp_number))
        else:
            samples = slice(None)

        self.alive[:, self.buffer_count] = bunch.alive[samples]
        for i, dim in enumerate(bunch):
            self.particles[:, i, self.buffer_count] = bunch.particles[dim][samples]
        
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write()
            self.buffer_count = 0
        

            
class BeamMonitor(Monitor):
    """
    Monitor the full beam and save each bunch attributes (mean, std, emit and 
    current).
    
    Parameters
    ----------
    h : int
        Size of the bunch axis of the saved arrays: data for bunch number i is
        stored at index i (presumably the harmonic number of the ring).
    file_name : string, optional
        Name of the HDF5 where the data will be stored. Must be specified
        the first time a subclass of Monitor is instantiated and must be None
        the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every 
        call of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(beam)
        Save data    
    """
    
    def __init__(self, h, file_name=None, save_every=5, buffer_size=500, 
                 total_size=2e4, mpi_mode=True):
        
        group_name = "Beam"
        # Sizes are documented as "int or float"; shapes need integers.
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        dict_buffer = {"mean" : (6, h, buffer_size), 
                       "std" : (6, h, buffer_size),
                       "emit" : (3, h, buffer_size),
                       "current" : (h, buffer_size)}
        dict_file = {"mean" : (6, h, total_size), 
                     "std" : (6, h, total_size),
                     "emit" : (3, h, total_size),
                     "current" : (h, total_size)}
        
        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        # Stored for consistency with the other monitors.
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
                    
    def track(self, beam):
        """
        Save data
        
        Parameters
        ----------
        beam : Beam object
        """     
        if self.track_count % self.save_every == 0:
            if beam.mpi_switch:
                # In MPI mode, only the bunch handled by this process is saved.
                bunch_num = beam.mpi.bunch_num
                self.to_buffer(beam[bunch_num], bunch_num)
            else:
                self.to_buffer_no_mpi(beam)
                    
        self.track_count += 1
        
    def to_buffer(self, bunch, bunch_num):
        """
        Save data to buffer, if mpi is being used.
        
        Parameters
        ----------
        bunch : Bunch object
        bunch_num : int
            Index of the bunch along the bunch axis of the buffers.
        """
        
        self.time[self.buffer_count] = self.track_count
        self.mean[:, bunch_num, self.buffer_count] = bunch.mean
        self.std[:, bunch_num, self.buffer_count] = bunch.std
        self.emit[:, bunch_num, self.buffer_count] = bunch.emit
        self.current[bunch_num, self.buffer_count] = bunch.current
        
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write(bunch_num)
            self.buffer_count = 0
            
    def to_buffer_no_mpi(self, beam):
        """
        Save data to buffer, if mpi is not being used.
        
        Parameters
        ----------
        beam : Beam object
        """
              
        self.time[self.buffer_count] = self.track_count
        self.mean[:, :, self.buffer_count] = beam.bunch_mean
        self.std[:, :, self.buffer_count] = beam.bunch_std
        self.emit[:, :, self.buffer_count] = beam.bunch_emit
        self.current[:, self.buffer_count] = beam.bunch_current
        
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write_no_mpi()
            self.buffer_count = 0

    def write(self, bunch_num):
        """
        Write data from buffer to the HDF5 file, if mpi is being used.
        
        Only the row of bunch_num is written in the bunch-indexed datasets.
        
        Parameters
        ----------
        bunch_num : int
        """
        start = self.write_count * self.buffer_size
        time_slice = slice(start, start + self.buffer_size)
        group = self.file[self.group_name]
        
        group["time"][time_slice] = self.time
        group["mean"][:, bunch_num, time_slice] = self.mean[:, bunch_num, :]
        group["std"][:, bunch_num, time_slice] = self.std[:, bunch_num, :]
        group["emit"][:, bunch_num, time_slice] = self.emit[:, bunch_num, :]
        group["current"][bunch_num, time_slice] = self.current[bunch_num, :]
                 
        self.file.flush() 
        self.write_count += 1
        
    def write_no_mpi(self):
        """
        Write data from buffer to the HDF5 file, if mpi is not being used.
        """
        start = self.write_count * self.buffer_size
        time_slice = slice(start, start + self.buffer_size)
        group = self.file[self.group_name]
        
        group["time"][time_slice] = self.time
        group["mean"][:, :, time_slice] = self.mean
        group["std"][:, :, time_slice] = self.std
        group["emit"][:, :, time_slice] = self.emit
        group["current"][:, time_slice] = self.current
                 
        self.file.flush() 
        self.write_count += 1

        
class ProfileMonitor(Monitor):
    """
    Monitor a single bunch and save bunch profiles.
    
    For each dimension dim, the profile is saved in the dataset dim and the
    n_bin + 1 bin edges in the dataset dim + "_bin".
    
    Parameters
    ----------
    bunch_number : int
        Bunch to monitor.
    dimensions : str or list of str, optional
        Dimensions to save.
    n_bin : int or list of int, optional
        Number of bins to use in each dimension. A single value is used for
        all dimensions; a list must have one value per dimension.
    file_name : string, optional
        Name of the HDF5 where the data will be stored. Must be specified
        the first time a subclass of Monitor is instantiated and must be None
        the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every 
        call of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(object_to_save)
        Save data.
    """
    
    def __init__(self, bunch_number, dimensions="tau", n_bin=75, file_name=None, 
                 save_every=5, buffer_size=500, total_size=2e4, mpi_mode=True):
        
        self.bunch_number = bunch_number
        group_name = "ProfileData_" + str(self.bunch_number)
        
        if isinstance(dimensions, str):
            self.dimensions = [dimensions]
        else:
            self.dimensions = dimensions
        
        # Accept any scalar (int, numpy integer, ...) as a shared bin count.
        if np.ndim(n_bin) == 0:
            self.n_bin = np.full((len(self.dimensions),), int(n_bin), dtype=int)
        else:
            self.n_bin = np.asarray(n_bin, dtype=int)
        if len(self.n_bin) != len(self.dimensions):
            raise ValueError("n_bin must be an int or have one value per "
                             "dimension.")
        
        # Sizes are documented as "int or float"; shapes need integers.
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        
        dict_buffer = {}
        dict_file = {}
        for dim, n in zip(self.dimensions, self.n_bin):
            dict_buffer[dim] = (n, buffer_size)
            dict_buffer[dim + "_bin"] = (n + 1, buffer_size)
            dict_file[dim] = (n, total_size)
            dict_file[dim + "_bin"] = (n + 1, total_size)

        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
        
    def to_buffer(self, bunch):
        """
        Save data to buffer.
        
        Parameters
        ----------
        bunch : Bunch object
        """

        self.time[self.buffer_count] = self.track_count
        for dim, n in zip(self.dimensions, self.n_bin):
            bins, _, profile = bunch.binning(dim, n)
            # n + 1 bin edges: every left edge plus the last right edge.
            bin_array = np.append(np.array(bins.left), bins.right[-1])
            getattr(self, dim + "_bin")[:, self.buffer_count] = bin_array
            getattr(self, dim)[:, self.buffer_count] = np.array(profile)
        
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write()
            self.buffer_count = 0
            
    def write(self):
        """Write data from buffer to the HDF5 file."""
        start = self.write_count * self.buffer_size
        time_slice = slice(start, start + self.buffer_size)
        group = self.file[self.group_name]
        
        group["time"][time_slice] = self.time
        for dim in self.dimensions:
            group[dim][:, time_slice] = getattr(self, dim)
            group[dim + "_bin"][:, time_slice] = getattr(self, dim + "_bin")
            
        self.file.flush()
        self.write_count += 1
                    
    def track(self, object_to_save):
        """
        Save data.
        
        Parameters
        ----------
        object_to_save : Bunch or Beam object
        """        
        self.track_bunch_data(object_to_save)
        
class WakePotentialMonitor(Monitor):
    """
    Monitor the wake potential from a single bunch.
    
    For each wake type W, the datasets "tau_" + W, "profile_" + W and W are
    saved; for "Wxdip" and "Wydip" the dataset "dipole_" + W is also saved.
    All are interpolated on a uniform grid of n_bin points.
    
    Parameters
    ----------
    bunch_number : int
        Bunch to monitor.
    wake_types : str or list of str
        Wake types to save: "Wlong", "Wxdip", ...
    n_bin : int
        Number of bins to be used to interpolate the wake potential on a fixed
        grid.
    file_name : string, optional
        Name of the HDF5 where the data will be stored. Must be specified
        the first time a subclass of Monitor is instantiated and must be None
        the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every 
        call of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(wake_potential_to_save)
        Save data.
    """
    
    # Wake types that also carry a dipole moment, mapped to the attribute of
    # the WakePotential object holding it.
    _dipole_attributes = {"Wxdip": "dipole_x", "Wydip": "dipole_y"}
    
    def __init__(self, bunch_number, wake_types, n_bin, file_name=None, 
                 save_every=5, buffer_size=500, total_size=2e4, mpi_mode=True):
        
        self.bunch_number = bunch_number
        group_name = "WakePotentialData_" + str(self.bunch_number)
        
        if isinstance(wake_types, str):
            self.wake_types = [wake_types]
        else:
            self.wake_types = wake_types
        
        # np.linspace and array shapes require integer sizes.
        self.n_bin = int(n_bin)
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        
        dict_buffer = {}
        dict_file = {}
        for dim in self.wake_types:
            keys = ["tau_" + dim, "profile_" + dim, dim]
            if dim in self._dipole_attributes:
                keys.append("dipole_" + dim)
            for key in keys:
                dict_buffer[key] = (self.n_bin, buffer_size)
                dict_file[key] = (self.n_bin, total_size)

        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
        
    def to_buffer(self, wp):
        """
        Save data to buffer.
        
        Parameters
        ----------
        wp : WakePotential object
        """

        self.time[self.buffer_count] = self.track_count
        for dim in self.wake_types:
            tau0 = getattr(wp, "tau0_" + dim)
            profile0 = getattr(wp, "profile0_" + dim)
            WP0 = getattr(wp, dim)
            dipole_attr = self._dipole_attributes.get(dim)
            if dipole_attr is not None:
                dipole0 = getattr(wp, dipole_attr)
            
            # Uniform grid spanning tau0; points outside tau0 are filled with 0.
            tau = np.linspace(tau0[0], tau0[-1], self.n_bin)
            WP = interp1d(tau0, WP0, fill_value=0, bounds_error=False)(tau)
            profile = interp1d(tau0, profile0, fill_value=0,
                               bounds_error=False)(tau)
            if dipole_attr is not None:
                dipole = interp1d(tau0, dipole0, fill_value=0,
                                  bounds_error=False)(tau)
            
            # The saved time grid is shifted by wp.tau_mean.
            getattr(self, "tau_" + dim)[:, self.buffer_count] = tau + wp.tau_mean
            getattr(self, "profile_" + dim)[:, self.buffer_count] = profile
            getattr(self, dim)[:, self.buffer_count] = WP
            if dipole_attr is not None:
                getattr(self, "dipole_" + dim)[:, self.buffer_count] = dipole
            
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write()
            self.buffer_count = 0
            
    def write(self):
        """Write data from buffer to the HDF5 file."""
        start = self.write_count * self.buffer_size
        time_slice = slice(start, start + self.buffer_size)
        group = self.file[self.group_name]
        
        group["time"][time_slice] = self.time
        for dim in self.wake_types:
            keys = ["tau_" + dim, "profile_" + dim, dim]
            if dim in self._dipole_attributes:
                keys.append("dipole_" + dim)
            for key in keys:
                group[key][:, time_slice] = getattr(self, key)
            
        self.file.flush()
        self.write_count += 1
                    
    def track(self, wake_potential_to_save):
        """
        Save data.
        
        Parameters
        ----------
        wake_potential_to_save : WakePotential object
        """        
        if self.track_count % self.save_every == 0:
            self.to_buffer(wake_potential_to_save)
        self.track_count += 1

class TuneMonitor(Monitor):
    """
    Monitor tunes and the Fourier transform (using FFT algorithm) of the 
    oscillation in horizontal, vertical, and longitudinal plane. 

    On every call to track, the x, y and tau coordinates of a fixed random
    sample of macro-particles are recorded. Every save_every calls, the
    recorded window of save_every+1 consecutive samples is used to compute
    the tune (NAFF) and/or the averaged FFT amplitude, which are buffered
    and then written to the HDF5 file.
    
    Parameters
    ----------
    ring : Synchrotron object
    bunch_number : int
        Bunch to monitor
    mp_number : int or float
        Total number of macro-particles in the bunch.
    sample_size : int or float
        Number of macro-particles to be used for tune and FFT computation.
        This number cannot exceed mp_number.
    save_tune : bool, optional
        If True, tune data is saved. 
    save_fft : bool, optional
        If True, FFT data is saved.    
    n_fft : int or float, optional
        The number of points used for FFT computation.
    file_name : string, optional
        Name of the HDF5 file where the data will be stored. Must be
        specified the first time a subclass of Monitor is instantiated and
        must be None the following times.
    save_every : int or float, optional
        Set the frequency of the save. The tune is computed and saved every
        save_every calls of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size    
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.
        
    Methods
    -------
    track(bunch):
        Save tune and/or FFT data.
    
    """
    
    def __init__(self, ring, bunch_number, mp_number, sample_size, save_tune=True,
                 save_fft=False, n_fft=10000, file_name=None, save_every=10, 
                 buffer_size=5, total_size=10, mpi_mode=True):
        
        self.ring = ring
        self.bunch_number = bunch_number
        group_name = "TuneData_" + str(self.bunch_number)
        
        # These sizes are documented as "int or float": cast them once so
        # they are valid array shapes and FFT lengths below.
        save_every = int(save_every)
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        n_fft = int(n_fft)
        n_freq = n_fft//2 + 1  # number of bins returned by rfft(n=n_fft)
        
        dict_buffer = {"tune":(3, buffer_size), "tune_spread":(3, buffer_size),
                       "fft":(3, n_freq, buffer_size)}
        dict_file = {"tune":(3, total_size), "tune_spread":(3, total_size),
                     "fft":(3, n_freq, total_size)}
        
        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
        
        # Per-particle coordinate history of the current window: one row per
        # sampled macro-particle, save_every+1 columns (one per sample).
        self.sample_size = int(sample_size)
        self.x = np.zeros((self.sample_size, save_every+1))
        self.y = np.zeros((self.sample_size, save_every+1))
        self.tau = np.zeros((self.sample_size, save_every+1))
        
        # Sampling from a range draws exactly the same indices as sampling
        # from a list of the same integers, without materializing mp_number
        # numpy scalars.
        self.index_sample = sorted(random.sample(range(int(mp_number)),
                                                 self.sample_size))
        
        self.save_count = 0
        # No data is buffered at track_count == 0 (see track), so
        # total_size*save_every calls buffer only total_size-1 windows.
        # Starting the buffer index at 1 keeps the last buffer full so that
        # it is written; slot 0 of the first buffer stays at its initial
        # value.
        self.buffer_count = 1
        
        self.save_tune = save_tune
        self.save_fft = save_fft
        self.save_every = save_every
        self.n_fft = n_fft
        
        if self.save_fft is True :
            self.fourier_save = np.zeros((3, n_freq, buffer_size))
        
    def track(self, object_to_save):
        """
        Record the sampled coordinates and, every save_every calls, compute
        and buffer the tune and/or FFT data.

        Parameters
        ----------
        object_to_save : Beam or Bunch object
            If a Beam, bunch number bunch_number is monitored. In MPI mode,
            calls made on a core whose bunch is not bunch_number are ignored
            (track_count is not incremented).

        Raises
        ------
        TypeError
            If object_to_save is neither a Beam nor a Bunch object.
        """
        if isinstance(object_to_save, Beam):
            if (object_to_save.mpi_switch == True):
                if object_to_save.mpi.bunch_num != self.bunch_number:
                    return
                bunch = object_to_save[object_to_save.mpi.bunch_num]
            else:
                bunch = object_to_save[self.bunch_number]
        elif isinstance(object_to_save, Bunch):
            bunch = object_to_save
        else:
            raise TypeError("object_to_save should be a Beam or Bunch object.")
        
        self.x[:, self.save_count] = bunch["x"][self.index_sample]
        self.y[:, self.save_count] = bunch["y"][self.index_sample]
        self.tau[:, self.save_count] = bunch["tau"][self.index_sample]
        self.save_count += 1
        
        if self.track_count > 0 and self.track_count % self.save_every == 0:
            self.to_buffer(bunch)
            # Reuse the last sample as the first one of the next window so
            # that every window holds save_every+1 fresh samples, like the
            # first one. Restarting at column 0 left the last column holding
            # a stale sample from the first window.
            for history in (self.x, self.y, self.tau):
                history[:, 0] = history[:, -1]
            self.save_count = 1

        self.track_count += 1
        
    def to_buffer(self, bunch):
        """
        A method to hold saved data before writing it to the output file.

        Computes the tune and/or FFT of the current window, stores them at
        index buffer_count, and writes the buffer to the file when full.

        Parameters
        ----------
        bunch : Bunch object
        """
        
        self.time[self.buffer_count] = self.track_count
        
        if self.save_tune is True:
            mean, spread = self.get_tune(bunch)
            self.tune[:, self.buffer_count] = mean
            self.tune_spread[:, self.buffer_count] = spread
        
        if self.save_fft is True:
            fx, fy, ftau = self.get_fft()
            self.fourier_save[0,:,self.buffer_count] = fx
            self.fourier_save[1,:,self.buffer_count] = fy
            self.fourier_save[2,:,self.buffer_count] = ftau
        
        self.buffer_count += 1
        
        if self.buffer_count == self.buffer_size:
            self.write()
            self.buffer_count = 0
            
    def write(self):
        """
        Write data from buffer to output file.

        The buffer is stored in the slot
        [write_count*buffer_size, (write_count+1)*buffer_size) of each
        saved dataset, then the file is flushed.
        """
        self.file[self.group_name]["time"][self.write_count*self.buffer_size:(
                    self.write_count+1)*self.buffer_size] = self.time

        if self.save_tune is True:
            self.file[self.group_name]["tune"][:, 
                    self.write_count * self.buffer_size:(self.write_count+1) * 
                    self.buffer_size] = self.tune
            self.file[self.group_name]["tune_spread"][:, 
                    self.write_count * self.buffer_size:(self.write_count+1) * 
                    self.buffer_size] = self.tune_spread
            
        if self.save_fft is True:
            self.file[self.group_name]["fft"][:,:, 
                    self.write_count * self.buffer_size:(self.write_count+1) * 
                    self.buffer_size] = self.fourier_save
            
        self.file.flush()
        self.write_count += 1

    def get_tune(self, bunch):
        """
        Compute tune by using NAFF algorithm to identify the fundamental
        harmonic frequency of the particles' motion.

        Parameters
        ----------
        bunch : Bunch object

        Returns
        -------
        mean, spread : ndarray of shape (3,)
            NaN-ignoring mean and standard deviation over the sampled
            particles of the single-particle tune in x, y and tau.
            Particles for which NAFF finds no frequency are NaN and ignored.
        """
        turns = self.save_every - 1
        freq = np.zeros((self.sample_size, 3))
        
        for i in range(self.sample_size):
            for plane, history in enumerate((self.x, self.y, self.tau)):
                try:
                    freq[i, plane] = pnf.naff(history[i, :], turns=turns,
                                              nterms=1)[0][1] / self.ring.T0
                except IndexError:
                    # naff returned no term for this particle.
                    freq[i, plane] = np.nan
        
        # Dividing by T0 then by f0 converts the NAFF frequency to a tune
        # (presumably NAFF returns it in units of the sampling rate).
        tune_single_particle = freq / self.ring.f0
        mean = np.nanmean(tune_single_particle, 0)
        spread = np.nanstd(tune_single_particle, 0)
        
        return (mean, spread)
    
    def get_fft(self):
        """
        Compute the Fourier transform (using FFT algorithm) of the 
        oscillation in horizontal, vertical, and longitudinal plane. 

        Each particle's window is transformed with rfft (zero-padded or
        truncated to n_fft points) and the absolute values are averaged
        over the sampled particles.

        Returns
        -------
        fourier_x_avg, fourier_y_avg, fourier_tau_avg : ndarray
            The average of the transformed input in each plane, of length
            n_fft//2 + 1.

        """
        fourier_x = rfft(self.x, n=self.n_fft)
        fourier_y = rfft(self.y, n=self.n_fft)
        fourier_tau = rfft(self.tau, n=self.n_fft)
        
        fourier_x_avg = np.mean(abs(fourier_x), axis=0)
        fourier_y_avg = np.mean(abs(fourier_y), axis=0)
        fourier_tau_avg = np.mean(abs(fourier_tau), axis=0)
        
        return (fourier_x_avg, fourier_y_avg, fourier_tau_avg)
        
class CavityMonitor(Monitor):
    """
    Monitor a CavityResonator object and save its cavity voltage and phase.
    
    Parameters
    ----------
    cavity_name : str
        Name of the CavityResonator object to monitor.
    file_name : string, optional
        Name of the HDF5 file where the data will be stored. Must be
        specified the first time a subclass of Monitor is instantiated and
        must be None the following times.
    save_every : int or float, optional
        Set the frequency of the save. The data is saved every save_every
        calls of the monitor.
    buffer_size : int or float, optional
        Size of the save buffer.
    total_size : int or float, optional
        Total size of the save. The following relationships between the 
        parameters must exist: 
            total_size % buffer_size == 0
            number of call to track / save_every == total_size
    mpi_mode : bool, optional
        If True, open the HDF5 file in parallel mode, which is needed to
        allow several cores to write in the same file at the same time.
        If False, open the HDF5 file in standard mode.

    Methods
    -------
    track(object_to_save)
        Save data
    """
    
    def __init__(self, cavity_name, file_name=None, save_every=5,
                 buffer_size=500, total_size=2e4, mpi_mode=True):
        """
        Declare the cavity_voltage and cavity_phase buffers and datasets and
        initialize the monitor under the group named cavity_name.
        """
        
        self.cavity_name = cavity_name
        group_name = cavity_name
        # Sizes are documented as "int or float" (total_size even defaults
        # to the float 2e4): cast them so they are valid array/dataset shapes.
        buffer_size = int(buffer_size)
        total_size = int(total_size)
        dict_buffer = {"cavity_voltage":(buffer_size,), 
                       "cavity_phase":(buffer_size,)}
        dict_file = {"cavity_voltage":(total_size,), 
                     "cavity_phase":(total_size,)}

        self.monitor_init(group_name, save_every, buffer_size, total_size,
                          dict_buffer, dict_file, file_name, mpi_mode)
        
        self.dict_buffer = dict_buffer
        self.dict_file = dict_file
                    
    def track(self, cavity):
        """
        Buffer the cavity data once every save_every calls.

        On calls where track_count is a multiple of save_every the cavity is
        passed to to_buffer; track_count is incremented on every call.

        Parameters
        ----------
        cavity : CavityResonator object

        Raises
        ------
        TypeError
            If, on a save call, cavity is not a CavityResonator object.
        """
        save_now = self.track_count % self.save_every == 0
        if save_now:
            # The type is only validated on calls that actually save.
            if not isinstance(cavity, CavityResonator):
                raise TypeError("cavity should be a CavityResonator object.")
            self.to_buffer(cavity)
        self.track_count += 1