Skip to content
Snippets Groups Projects
SoleilTools.py 11.8 KiB
Newer Older
BRONES Romain's avatar
BRONES Romain committed
# -*- coding: utf-8 -*-
"""
Tools for Soleil Synchrotron

@author: broucquart
"""

import numpy as np
import logging
import datetime
import matplotlib.colors as mcol
import pickle
import matplotlib

logger = logging.getLogger(__name__)

###############################################################################
# VECTORIZED DATE FUNCTIONS
###############################################################################
# ``otypes`` is given explicitly: numpy then does not need to call the function
# on the first element to guess the output dtype, and the helpers also work on
# empty arrays (np.vectorize raises ValueError on size-0 input without otypes).

# Array of POSIX timestamps (seconds) -> array of naive local datetime objects.
ArrayTimeStampToDatetime = np.vectorize(datetime.datetime.fromtimestamp,
                                        otypes=[object])
# Array of datetime objects -> array of POSIX timestamps (float seconds).
ArrayDatetimeToTimeStamp = np.vectorize(datetime.datetime.timestamp,
                                        otypes=[float])
# Array of "YYYY/MM/DD HH:MM:SS.ffffff" strings -> array of datetime objects.
ArrayStrpToDateTime = np.vectorize(
    lambda x: datetime.datetime.strptime(x, "%Y/%m/%d %H:%M:%S.%f"),
    otypes=[object])


###############################################################################
# DATA IMPORTATION
###############################################################################

##---------------------------------------------------------------------------##
def load_filer_trend(filename, delimiter='\t'):
    """
    Load data from a file generated by atkfilertrend.

    The file is a delimited text file. Its first line is a header holding the
    column names and its last line is a footer; both are skipped when reading
    the numeric values. The first column holds the timestamps, which are
    divided by 1000 before conversion (i.e. assumed to be in milliseconds).

    Parameters
    ----------
    filename : String
        Path to the file to load.
    delimiter : str, optional
        Column separator used in the file. The default is a tab character.

    Returns
    -------
    ddata : dict
        Dictionary of data. Key is the attribute tango path (read from the
        header), data is the numpy array of data.
        The special key "Time" holds the timestamps, as datetime objects.

    """
    # Load the numeric data; header (first line) and footer (last line) are
    # skipped. Transpose so that data[i] is the i-th column of the file.
    logger.info("Load file %s", filename)
    data = np.genfromtxt(filename, skip_header=1, skip_footer=1,
                         delimiter=delimiter).transpose()
    logger.debug("data shape : %s", data.shape)

    # Read the first line and parse attribute names
    with open(filename, 'r') as fp:
        head = fp.readline()

    # Split head
    logger.debug("read head : %s", head)
    head = head.split(delimiter)
    logger.debug("parsed head : %s", head)

    # First column: milliseconds -> seconds -> datetime objects.
    # The header name of this column is not used; the key is always "Time".
    ddata = {"Time": ArrayTimeStampToDatetime(data[0]/1000)}

    # Attach each remaining column to its header name. The last header field
    # is not used as a key; presumably the header line ends with a trailing
    # delimiter so that field only holds the newline -- TODO confirm against
    # an actual atkfilertrend export.
    for i in range(1, len(head)-1):
        ddata[head[i]] = data[i]

    return ddata

##---------------------------------------------------------------------------##
def load_mambo_file(filename):
    """
    Load data from a file extracted from Mambo.

    The file is tab separated. Its first line is a header of attribute names
    and its first column holds dates parsed with the format
    "%Y/%m/%d %H:%M:%S.%f". Cells equal to "*" are treated as missing and are
    dropped, independently for each attribute.

    Parameters
    ----------
    filename : string
        Filepath.

    Returns
    -------
    ddata : dict
        Dictionary of data. Key is the attribute tango path, data is a tuple of
        two numpy arrays. First array is datetime values, second is attribute
        value (float).

    """
    # Load the file data as string (the "*" markers are not numeric).
    # Transpose so that data[n] is the n-th column of the file.
    logger.info("Load file %s", filename)
    data = np.genfromtxt(filename, delimiter='\t', skip_header=1, dtype=str).transpose()
    logger.debug("data shape : %s", data.shape)

    # Read the first line and parse attribute names
    with open(filename, 'r') as fp:
        head = fp.readline()

    # Split head after removing the trailing newline. rstrip (rather than
    # slicing off the last character) does not eat a real character when the
    # header line has no newline.
    logger.debug("read head : %s", head)
    head = head.rstrip('\n').split('\t')
    logger.debug("parsed head : %s", head)

    # Convert string to datetime
    tdata = ArrayStrpToDateTime(data[0])

    ddata = dict()
    # Find correct values for each dataset (ignore "*")
    # Add to dictionary, key is the attribute tango path, value is tuple of
    # time array and value array
    for n in range(1, len(data)):
        good = np.where(data[n] != "*")[0]
        # Builtin float: the np.float alias was removed in NumPy 1.24.
        ddata[head[n]] = (tdata[good], data[n][good].astype(float))

    return ddata

###############################################################################
# SIGNAL PROCESSING
###############################################################################

##---------------------------------------------------------------------------##
def MM(datax, datay, N, DEC=1):
    """
    Mobile Mean along x. Averaging window of N points.

    Parameters
    ----------
    datax : numpy.ndarray or None
        X axis, will only be cut at edge to match the length of mean Y.
        Set to "None" if no X-axis; sample indices are then used.
    datay : numpy.ndarray
        Y axis, will be averaged.
    N : int
        Averaging window length in points. Must be >= 1.
    DEC : int, optional
        Decimation factor applied to both outputs (keep one point every DEC).
        The default is 1 (no decimation).

    Returns
    -------
    Tuple of numpy.ndarray
        (X axis, Y axis) averaged data. Both arrays always have the same
        length, max(len(datay) - N + 1, 0) before decimation.

    Raises
    ------
    ValueError
        If N < 1.

    """
    if N < 1:
        raise ValueError("Averaging window N must be >= 1 (got %r)." % (N,))

    datay = np.asarray(datay)

    # Number of complete windows of N points (length of a "valid" convolution).
    # Computed explicitly: np.convolve swaps its inputs when the kernel is
    # longer than the signal, which would give a meaningless result here.
    nvalid = max(len(datay) - N + 1, 0)
    if nvalid == 0:
        ymean = np.empty(0)
    else:
        ymean = np.convolve(datay, np.ones(N)/N, mode='valid')

    # Window k covers samples k .. k+N-1; its mean is attached to sample
    # k + N//2. Slicing with an explicit length keeps X and Y aligned for
    # every N (the former negative-index slice was empty for N <= 2).
    if datax is None:
        xmean = np.arange(N//2, N//2 + nvalid)
    else:
        xmean = np.asarray(datax)[N//2:N//2 + nvalid]

    return (xmean[::DEC], ymean[::DEC])


##---------------------------------------------------------------------------##
def meanstdmaxmin(x, y, N):
    """
    Compute mean, max, min and +- std over block of N points on the Y axis.
    Return arrays of length len(y)//N points; trailing points that do not fill
    a complete block are dropped.

    Parameters
    ----------
    x : numpy.ndarray or sequence
        X vector, i.e sampling times. May hold datetime.datetime objects, in
        which case the returned xmean holds datetime objects too.
    y : numpy.ndarray or sequence
        Y vector, i.e. values.
    N : int
        Number of points per block.

    Returns
    -------
    xmean : numpy.ndarray
        New x vector (mean x position of each block).
    ymean : numpy.ndarray
        Means of Y.
    ystd : numpy.ndarray
        Std of Y.
    ymax : numpy.ndarray
        Maxes of Y.
    ymin : numpy.ndarray
        Mins of Y.

    """
    # Accept plain sequences as well as arrays (x[1:]-x[:-1] needs arrays).
    x = np.asarray(x)
    y = np.asarray(y)

    # If x vector is datetime, convert to timestamps. isinstance also accepts
    # datetime subclasses; the length guard avoids IndexError on empty x.
    xIsDatetime = len(x) > 0 and isinstance(x[0], datetime.datetime)
    if xIsDatetime:
        x = ArrayDatetimeToTimeStamp(x)

    # Quick verification on the X data vector jitter. Needs at least two
    # samples, otherwise the mean of an empty difference vector is undefined.
    if len(x) >= 2:
        dx = x[1:] - x[:-1]
        period = np.mean(dx)
        jitter = np.std(dx)
        if jitter > 0.01*period:
            logger.warning("On X data vector : sampling jitter is over 1%% of the period. (j=%.3g, p=%.3g)",
                           jitter, period)

    # Number of complete blocks of N points
    nblocks = len(y)//N

    # Reshape to (nblocks, N), dropping the last points that do not fill a
    # block of N points.
    _x = np.reshape(x[:nblocks*N], (nblocks, N))
    _y = np.reshape(y[:nblocks*N], (nblocks, N))

    # New x vector: mean abscissa of each block.
    xmean = np.mean(_x, axis=1)
    if xIsDatetime:
        xmean = ArrayTimeStampToDatetime(xmean)

    # Per-block statistics of Y
    ymean = np.mean(_y, axis=1)
    ystd = np.std(_y, axis=1)
    ymin = np.min(_y, axis=1)
    ymax = np.max(_y, axis=1)

    return (xmean, ymean, ystd, ymax, ymin)
    
###############################################################################
## PLOTTING
###############################################################################

##---------------------------------------------------------------------------##
def plot_meanstdmaxmin(ax, datax, datay, N,
                       c=None, label=None):
    """
    Plot on an ax the representation in mean, +- std and min max.

    The data is cut into consecutive blocks of ``max(len(datax)//N, 1)``
    points and meanstdmaxmin computes mean, std, max and min for each block.
    The mean is drawn as a line, max and min as half-transparent lines, and
    the mean +- std band as a light filled area, all in the same color.

    Parameters
    ----------
    ax : matplotlib.axes._base._AxesBase
        Ax on which to plot.
    datax : numpy.ndarray
        X axis.
    datay : numpy.ndarray
        Y axis.
    N : int
        Approximate number of points to display. This is not the block
        length: the block length is ``max(len(datax)//N, 1)``.
    c : color, optional
        Color. The default is None (the axes picks the color).
    label : str, optional
        Label of the mean line. The default is None.

    Returns
    -------
    lines : list
        [mean line, max line, min line, std area]: three Line2D objects
        followed by the artist returned by ``ax.fill_between``.

    """
    # Consider the whole data range. Compute the averaging ratio (block
    # length) so that about N points are displayed. Minimum ratio is 1.
    ratio = max(len(datax)//N, 1)

    # Compute new data
    xmean, ymean, ystd, ymax, ymin = meanstdmaxmin(datax, datay, ratio)

    lines = []
    # First, plot the mean with the given attributes
    lines.append(ax.plot(xmean, ymean, color=c, label=label)[0])

    # Retrieve the color, useful if c was None
    c = lines[0].get_color()

    # Add max, min and std area
    lines.append(ax.plot(xmean, ymax, linestyle='-', color=mcol.to_rgba(c, 0.5))[0])
    lines.append(ax.plot(xmean, ymin, linestyle='-', color=mcol.to_rgba(c, 0.5))[0])
    lines.append(ax.fill_between(xmean, ymean-ystd, ymean+ystd, color=mcol.to_rgba(c, 0.1)))

    return lines

##---------------------------------------------------------------------------##
def plot_MM(ax, datax, datay, N, DEC=1,
            c=None, label=None):
    """
    Plot a signal with its mobile mean. The signal is plotted with transparency.

    Parameters
    ----------
    ax : matplotlib.axes._base._AxesBase
        Axe on which to plot.
    datax : numpy.ndarray, None
        X axis data. If None, sample indices are used.
    datay : numpy.ndarray
        Y axis data.
    N : int
        Averaging window length in points.
    DEC : int, optional
        Decimation factor applied to the mobile mean (see MM).
        The default is 1.
    c : color, optional
        Line color. The default is None (the axes picks the color).
    label : str, optional
        Label of the mobile-mean line. The default is None.

    Returns
    -------
    lines : list of matplotlib.lines.Line2D
        [mobile-mean line, raw-signal line].

    """
    # To collect lines
    lines = []

    # Plot mobile mean (carries the label)
    lines.append(ax.plot(*MM(datax, datay, N, DEC), c=c, label=label)[0])

    # Retrieve the color, useful if c was None
    c = lines[0].get_color()
    faded = mcol.to_rgba(c, 0.4)

    # Plot entire signal, same color with 0.4 alpha. The line is now returned
    # too (it used to be created and then dropped).
    if datax is None:
        # Case no xaxis data
        lines.append(ax.plot(datay, c=faded)[0])
    else:
        lines.append(ax.plot(datax, datay, c=faded)[0])

    return lines

###############################################################################
## PLOT MANIPULATION
###############################################################################

##---------------------------------------------------------------------------##
def get_current_ax_zoom(ax):
    """
    Get the current ax zoom setup and print the python command to set it exactly.

    Parameters
    ----------
    ax : numpy.ndarray or axes
        Array of axes of any dimension (e.g. as returned by plt.subplots),
        printed as ``ax[i]`` / ``ax[i,j]``; or a single axes object (anything
        with get_xlim/get_ylim), printed as ``ax``.

    Raises
    ------
    NotImplementedError
        When the type is not implemented. It is time to implement it !

    Returns
    -------
    None.

    """
    def _as_floats(lim):
        # Python floats print as "(0.0, 1.0)"; numpy scalars would print as
        # "np.float64(0.0)" with NumPy >= 2.
        return tuple(float(v) for v in lim)

    if isinstance(ax, np.ndarray):
        # ndenumerate walks arrays of any dimension in C order, so 1-D arrays
        # keep the former "ax[i]" output and 2-D grids give "ax[i,j]".
        for idx, a in np.ndenumerate(ax):
            name = "ax[%s]" % ",".join(str(i) for i in idx)
            print(name + ".set_xlim" + str(_as_floats(a.get_xlim())))
            print(name + ".set_ylim" + str(_as_floats(a.get_ylim())))
        return

    # Single axes object
    if hasattr(ax, "get_xlim") and hasattr(ax, "get_ylim"):
        print("ax.set_xlim" + str(_as_floats(ax.get_xlim())))
        print("ax.set_ylim" + str(_as_floats(ax.get_ylim())))
        return

    raise NotImplementedError("Type is %s" % type(ax))

###############################################################################
## DATE PROCESSING
###############################################################################

##---------------------------------------------------------------------------##
def get_time_region(t, startDate, endDate):
    """
    Return a range of index selecting the ones between the start and stop date.

    The selection is half-open: samples with startDate <= t < endDate.

    Parameters
    ----------
    t : numpy.ndarray
        An array of datetime objects, sorted in increasing order (required by
        numpy.searchsorted).
    startDate : datetime.datetime
        Start date (included).
    endDate : datetime.datetime
        Stop date (excluded).

    Returns
    -------
    zone : numpy.ndarray
        A numpy arange between both index. Empty (with a warning logged) if
        no sample falls in the interval.

    """
    # Left-side searchsorted on both bounds: startDate is included, endDate
    # is excluded.
    iT1 = np.searchsorted(t, startDate)
    iT2 = np.searchsorted(t, endDate)
    zone = np.arange(iT1, iT2)
    if len(zone) == 0:
        # Module logger, like the rest of the file. The root-level
        # logging.warning() would call basicConfig() when the root logger has
        # no handler.
        logger.warning("Time zone is empty.")
    return zone

###############################################################################
# DATA EXPORTATION
###############################################################################

##---------------------------------------------------------------------------##
def export_mpl(fig, filename):
    """
    Export figure to .mpl file.

    The figure is serialized with pickle into ``filename + ".mpl"``. Reload it
    with pickle.load, and only from trusted files: unpickling can execute
    arbitrary code.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to export (subclasses are accepted).
    filename : str
        Filename, without extension.

    Raises
    ------
    TypeError
        If fig is not a matplotlib Figure.

    Returns
    -------
    None.

    """
    # Local import: ``import matplotlib`` alone does not guarantee that the
    # matplotlib.figure submodule is loaded. Accessing matplotlib.figure.Figure
    # could then raise AttributeError instead of the intended TypeError.
    from matplotlib.figure import Figure

    if not isinstance(fig, Figure):
        raise TypeError("Parameter fig should be a matplotlib figure (type matplotlib.figure.Figure).")
    with open(filename+".mpl", 'wb') as fp:
        pickle.dump(fig, fp)