# -*- coding: utf-8 -*-
"""
This module defines utility functions that help to deal with tracking output
and hdf5 files.

@author: Alexis Gamelin
"""

import h5py as hp

def merge_files(files_prefix, files_number, file_name=None):
    """
    Merge several hdf5 files into one.

    The function assumes that the files to merge have names in the following
    format:
        - "files_prefix_0.hdf5"
        - "files_prefix_1.hdf5"
        ...
        - "files_prefix_<files_number - 1>.hdf5"

    The layout (groups, dataset names, shapes and dtypes) of the merged file
    is taken from the first file, "files_prefix_0.hdf5". Each dataset of the
    merged file has the same shape as in the first file except for its last
    axis, which is files_number times longer. The data of file number i is
    written in the i-th chunk of that last axis.

    Datasets named "time" are treated specially: for every file but the
    first one, the last value already written in the merged "time" dataset
    is added to the data of the file before it is written.

    Parameters
    ----------
    files_prefix : str
        Common prefix of the files to merge, without the "_<number>.hdf5"
        suffix.
    files_number : int
        Number of files to merge.
    file_name : str, optional
        Name of the file with the merged data, without the ".hdf5"
        extension. If None, files_prefix without number is used.

    Notes
    -----
    The merged file is opened in append mode and its datasets are always
    created, so merging into a file which already contains a dataset with
    the same name fails.

    The position where the data of file number i is written is computed from
    the last-axis length of that file, so all the files are expected to hold
    datasets of identical shapes.

    """
    if file_name is None:
        file_name = files_prefix

    def part_name(index):
        # Name of the index-th file to merge.
        return files_prefix + "_" + str(index) + ".hdf5"

    # Context managers guarantee that every file is closed, even if an error
    # is raised while merging (missing file, shape mismatch, ...).
    with hp.File(file_name + ".hdf5", "a") as merged:

        ## Create file architecture
        with hp.File(part_name(0), "r") as first:
            for group_name in first:
                source_group = first[group_name]
                target_group = merged.require_group(group_name)
                for dataset_name in source_group:
                    dataset = source_group[dataset_name]
                    # Only the last axis grows, all the other ones are kept.
                    shape_needed = list(dataset.shape)
                    shape_needed[-1] = shape_needed[-1] * files_number
                    target_group.create_dataset(dataset_name,
                                                tuple(shape_needed),
                                                dataset.dtype)

        ## Copy data
        for i in range(files_number):
            with hp.File(part_name(i), "r") as part:
                for group_name in part:
                    source_group = part[group_name]
                    target_group = merged[group_name]
                    for dataset_name in source_group:
                        source = source_group[dataset_name]
                        target = target_group[dataset_name]
                        shape = source.shape
                        length = shape[-1]
                        start = length * i
                        # Full slices on the leading axes, i-th chunk on the
                        # last axis.
                        region = (slice(None),) * (len(shape) - 1)
                        region += (slice(start, start + length),)
                        data = source[...]
                        if (dataset_name == "time") and (i != 0):
                            # Shift by the last value written so far.
                            # NOTE(review): the index applies to the first
                            # axis, which matches the last axis only for 1D
                            # "time" datasets - confirm they are always 1D.
                            data = target[start - 1] + data
                        target[region] = data