"""
Python module for extracting attribute from Arhive Extractor Device.
"""
import logging
import datetime
import numpy as np
import PyTango as tango
import pandas as pd
import traceback

__version__ = "1.0.1"

##########################################################################
###                 Install logger for the module                      ###
##########################################################################
# Module-level logger; its level is set later by init() via the 'loglevel' argument.
logger = logging.getLogger(__name__)
# NOTE(review): the commented line below is dead code and would fail anyway --
# logger.upper() does not exist; presumably a loglevel string was meant.
#logger.setLevel(getattr(logging, logger.upper()))

if not logger.hasHandlers():
    # No handlers, create one
    sh = logging.StreamHandler()
    # Mirror the logger's current level on the new handler.
    sh.setLevel(logger.level)
    sh.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))
    logger.addHandler(sh)


##########################################################################
###               Commodity private variables                          ###
##########################################################################

# Extractor date format for GetAttDataBetweenDates
# NOTE(review): in this module _DBDFMT is actually passed to GetNearestValue /
# ExtractBetweenDates / GetAttDataBetweenDates, while _DBDFMT2 is passed to the
# GetAttData*BetweenDates(Count) commands -- these two comments look swapped;
# confirm against the extractor device API.
_DBDFMT = "%Y-%m-%d %H:%M:%S"

# Extractor date format for GetNearestValue
_DBDFMT2 = "%d-%m-%Y %H:%M:%S"

##########################################################################
###               Commodity private functions                          ###
##########################################################################

# Vectorized fromtimestamp function
# NOTE: it is faster than using pandas.to_datetime()
_ArrayTimeStampToDatetime = np.vectorize(datetime.datetime.fromtimestamp)

def _check_initialized():
    """
    Verify that init() has been run and both extractor proxies exist.

    Returns
    -------
    success : boolean
        True when both extractor device proxies are available.
    """
    global _extractors
    if None not in _extractors:
        return True
    logger.error("Module {0} is not initialied. You should run {0}.init().".format(__name__))
    return False

##----------------------------------------------------------------------##
def _dateparse(datestr):
    """
    Convenient function to parse date or duration strings.
    Exact date format is %Y-%m-%d-%H:%M:%S and it can be reduced to be less precise.
    Duration format is 'Xu' where X is a number and u is a unit in
    ('m':minutes, 'h':hours, 'd':days, 'M':months).
    If datestr is None, the current date and time is returned.

    Parameters
    ---------
    datestr : string or None
        Date as a string, format %Y-%m-%d-%H:%M:%S or less precise.
        Duration as a string, format 'Xu' where X is a number and u is a unit
        in ('m':minutes, 'h':hours, 'd':days, 'M':months)

    Exceptions
    ----------
    ValueError
        If the parsing failed.

    Returns
    -------
    date : datetime.datetime or datetime.timedelta
        Parsed date or duration
    """
    logger.debug("Parsing date string '%s'"%datestr)

    # As documented: no argument means "now".
    # BUGFIX: None previously crashed on datestr[-1] below.
    if datestr is None:
        return datetime.datetime.now()

    # Determine date/duration by looking at the last char
    if datestr[-1] in "mhdM":
        # Duration
        logger.debug("Assuming a duration")

        try:
            q=float(datestr[:-1])
        except ValueError as e:
            logger.error("Failed to parse date string. Given the last character, a duration was assumed.")
            # Raise ValueError (as documented) instead of a bare Exception;
            # still caught by callers that catch Exception.
            raise ValueError("Could not parse argument to a date") from e

        # Convert all in minutes
        # BUGFIX: the mapping had a duplicate 'm' key, so 'M' (months) raised
        # KeyError and 'm' (minutes) silently meant months.
        minutes = q*{'m':1, 'h':60, 'd':60*24, 'M':30*60*24}[datestr[-1]]

        return datetime.timedelta(minutes=minutes)

    else:
        # Probably a date string

        # This gives all format that will be tried, in order.
        # Stop on first parse success. Raise error if none succeed.
        fmt = [
            "%Y-%m-%d-%H:%M:%S",
            "%Y-%m-%d-%H:%M",
            "%Y-%m-%d-%H",
            "%Y-%m-%d",
            "%Y-%m",
            ]

        for f in fmt:
            try:
                return datetime.datetime.strptime(datestr, f)
            except ValueError:
                continue

        raise ValueError("Could not parse argument to a date")

##----------------------------------------------------------------------##
def _check_attribute(attribute, db):
    """
    Ensure that the attribute is archived in the selected database.

    Parameters
    ----------
    attribute : String
        Name of the attribute. Full Tango name i.e. "test/dg/panda/current".

    db: str
        Which database to look in, 'H' or 'T'.

    Raises
    ------
    ValueError
        When the attribute is not archived in the selected database.
    """
    global _extractors

    extractor = _extractors[{'H':0, 'T':1}[db]]

    logger.debug("Check that %s is archived."%attribute)
    if not extractor.IsArchived(attribute):
        msg = "Attribute '%s' is not archived in DB %s"%(attribute, extractor)
        logger.error(msg)
        raise ValueError(msg)

##----------------------------------------------------------------------##
def _chunkerize(attribute, dateStart, dateStop, db, Nmax=100000):
    """
    Split the extraction period into chunks of at most Nmax data atoms.

    Parameters
    ----------
    attribute : String
        Name of the attribute. Full Tango name i.e. "test/dg/panda/current".

    dateStart : datetime.datetime
        Start date for extraction.

    dateStop : datetime.datetime
        Stop date for extraction.

    db: str
        Which database to look in, 'H' or 'T'.

    Nmax: int
        Max number of atoms in one chunk. Default 100000.

    Returns
    -------
    cdates : list
        List of datetime giving the limit of each chunks.
        For N chunks, there is N+1 elements in cdates, as the start and end boundaries are included.
    """
    info=infoattr(attribute, db=db)
    logger.debug("Attribute information \n%s"%info)

    # Ask the extractor how many entries exist on the period.
    N=_extractors[{'H':0, 'T':1}[db]].GetAttDataBetweenDatesCount([
            attribute,
            dateStart.strftime(_DBDFMT2),
            dateStop.strftime(_DBDFMT2)
            ])
    logger.debug("On the period, there is %d entries"%N)

    # For vector attributes, each entry carries up to max_dim_x atoms.
    dx=int(info["max_dim_x"])
    if dx > 1:
        logger.debug("Attribute is a vector with max dimension = %s"%dx)
        N=N*dx

    # If data chunk is too much, we need to cut it
    if N > Nmax:
        # Chunk duration so that each chunk holds roughly Nmax atoms
        # (assumes entries are evenly spread over the period).
        dt = (dateStop-dateStart)/(N//Nmax)
        cdates = [dateStart]
        while cdates[-1] < dateStop:
            cdates.append(cdates[-1]+dt)
        # Clamp the last boundary to the requested stop date.
        cdates[-1] = dateStop
        logger.debug("Cutting access to %d little chunks of time, %s each."%(len(cdates)-1, dt))
    else:
        # Small enough: a single chunk covering the whole period.
        cdates=[dateStart, dateStop]

    return cdates

##----------------------------------------------------------------------##
def _cmd_with_retry(dp, cmd, arg, retry=2):
    """
    Run a command on tango.DeviceProxy, retrying on DevFailed.

    Parameters
    ----------
    dp: tango.DeviceProxy
        Device proxy to try command onto.

    cmd : str
        Command to execute on the extractor.

    arg : list
        Attribute to pass to the command.

    retry : int
        Number of command retry on DevFailed.

    Returns
    -------
    cmdreturn :
        Whatever the command returns.
        None if failed after the amount of retries.
    """
    logger.info("Perform Command {} {}".format(cmd, arg))

    for attempt in range(retry):
        # Make retrieval request
        logger.debug("Execute %s (%s)"%(cmd, arg))
        try:
            cmdreturn = getattr(dp, cmd)(arg)
        except tango.DevFailed as e:
            logger.warning("The extractor device returned the following error:")
            logger.warning(e)
            if attempt == retry-1:
                # Retries exhausted; give up.
                logger.error("Could not execute command %s (%s). Check the device extractor"%(cmd, arg))
                return None
            logger.warning("Retrying...")
        else:
            return cmdreturn
    return cmdreturn

##########################################################################
###                  Module private variables                          ###
##########################################################################
# Tuple of extractor for HDB and TDB
# (filled by init(); (None, None) means "not initialized yet")
_extractors = (None, None)

# Tuple for attribute tables
# (one attribute-name table per extractor, filled by init())
_AttrTables = (None, None)

##########################################################################
###                Module initialisation functions                     ###
##########################################################################

def init(
        HdbExtractorPath="archiving/hdbextractor/2",
        TdbExtractorPath="archiving/tdbextractor/2",
        loglevel="info",
        ):
    """
    Initialize the module: instanciate tango.DeviceProxy objects for the
    HDB and TDB extractors and cache their attribute lookup tables.

    Parameters:
    -----------
    HdbExtractorPath, TdbExtractorPath: string
        Tango path to the extractors.

    loglevel: string
        loglevel to pass to logging.Logger
    """
    global _extractors
    global _AttrTables

    try:
        logger.setLevel(getattr(logging, loglevel.upper()))
    except AttributeError:
        logger.error("Wrong log level specified: {}".format(loglevel.upper()))

    logger.debug("Instanciating extractors device proxy...")

    _extractors = tuple(tango.DeviceProxy(path) for path in (HdbExtractorPath, TdbExtractorPath))
    logger.debug("{} and {} instanciated.".format(*_extractors))

    logger.debug("Configuring extractors device proxy...")
    for proxy in _extractors:
        # 3 second communication timeout on every extractor call
        proxy.set_timeout_millis(3000)

    logger.debug("Filling attributes lookup tables...")
    _AttrTables = tuple(proxy.getattnameall() for proxy in _extractors)
    logger.debug("HDB: {} TDB: {} attributes counted".format(len(_AttrTables[0]), len(_AttrTables[1])))

##########################################################################
###                    Module access functions                         ###
##########################################################################

def extract(
        attr,
        date1, date2=None,
        method="nearest",
        db='H',
        ):
    """
    Access function to perform extraction between date1 and date2.
    Can extract one or several attributes.
    date1 and date2 can both be exact dates, or one of the two can be a time
    interval that will be taken relative to the other.

    Parameters:
    -----------
    attr: string, list, dict
        Attribute(s) to extract.
        If string, extract the given attribute, returning a pandas.Series.
        If list, extract attributes and return a list of pandas.Series.
        If a dict, extract attributes and return a dict of pandas.Series with same keys.

    date1, date2: string, datetime.datetime, datetime.timedelta, None
        Exact date, or duration relative to the other date.
        If string, it will be parsed.
        A start date can be given with string format '%Y-%m-%d-%H:%M:%S' or less
        precise (ie '2021-02', '2022-11-03', '2022-05-10-21:00'...).
        A duration can be given with string format 'Xu' where X is a number and
        u is a unit in ('m':minutes, 'h':hours, 'd':days, 'M':months).
        A datetime.datetime object or datetime.timedelta object will be used as is.
        date2 can be None. In that case it is replaced by the current time.

    method: str
        Method of extraction
            'nearest': Retrieve nearest value of date1, date2 is ignored.
            'between': Retrieve data between date1 and date2.

    db: str
        Which database to look in, 'H' or 'T'.

    Returns
    -------
    pandas.Series, list, dict or None
        Shape depends on 'attr' (see above). Attributes that fail to extract
        are skipped in list/dict results; None is returned for a failed
        single-attribute request.
    """

    ## _-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_
    #    Perform a few sanity checks
    if not _check_initialized():
        # Stop here, the function has produced a message if necessary
        return

    if not db in ("H", "T"):
        raise ValueError("Attribute 'db' should be 'H' or 'T'")

    allowedmethods=("nearest", "between", "minmaxmean")
    if not method in allowedmethods:
        raise ValueError("Attribute 'method' should be in {}".format(str(allowedmethods)))

    ## _-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_
    #     Work with dates
    if not type(date1) in (datetime.datetime, datetime.timedelta):
        date1 = _dateparse(date1)
    if date2 is None:
        date2 = datetime.datetime.now()
    elif not type(date2) in (datetime.datetime, datetime.timedelta):
        date2 = _dateparse(date2)

    if not datetime.datetime in (type(date1), type(date2)):
        logger.error("One of date1 date2 should be an exact date.\nGot {} {}".format(date1, date2))
        raise ValueError("date1 and date2 not valid")

    # Use timedelta relative to the other date. date1 is always before date2
    if type(date1) is datetime.timedelta:
        date1 = date2-date1
    if type(date2) is datetime.timedelta:
        date2 = date1+date2

    if date1 > date2:
        logger.error("date1 must precede date2.\nGot {} {}".format(date1, date2))
        raise ValueError("date1 and date2 not valid")

    ## _-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_-_
    #      Perform extraction and return

    # Sentinel distinguishing "extraction failed" from a legitimate None result.
    _FAILED = object()

    def _safe_extract(v):
        # One extraction; log and swallow failures so that a single broken
        # attribute does not abort a multi-attribute request.
        try:
            return _extract_attribute(v, method, date1, date2, db)
        except Exception as e:
            logger.debug("Exception in _extract_attribute(): "+str(e))
            # BUGFIX: traceback.print_tb() writes to stderr and returns None,
            # so the traceback never reached the log; format_tb() returns it.
            logger.debug("".join(traceback.format_tb(e.__traceback__)))
            logger.error("Could not extract {}.".format(v))
            return _FAILED

    if type(attr) is dict:
        d = dict()
        for k, v in attr.items():
            res = _safe_extract(v)
            if res is not _FAILED:
                d[k] = res
        return d

    if type(attr) in (list, tuple):
        d = []
        for v in attr:
            res = _safe_extract(v)
            if res is not _FAILED:
                d.append(res)
        return d

    res = _safe_extract(attr)
    return None if res is _FAILED else res


##----------------------------------------------------------------------##
def findattr(pattern, db="H"):
    """
    Search for an attribute path using the pattern given.
    Case insensitive.

    Parameters:
    -----------
    pattern: str
        Pattern to search, wildchar * accepted.
        example "dg*dcct*current"

    db: str
        Which database to look in, 'H' or 'T'.

    Returns:
    --------
    results: (str,)
        List of string match
    """
    if not _check_initialized():
        return

    if not db in ("H", "T"):
        raise AttributeError("Attribute db should be 'H' or 'T'")

    global _AttrTables

    keywords=pattern.lower().split('*')

    # Select DB
    attr_table = _AttrTables[{'H':0, 'T':1}[db]]

    matches = [attr for attr in attr_table if all(k in attr.lower() for k in keywords)]

    return matches

##----------------------------------------------------------------------##
def infoattr(attribute, db='H'):
    """
    Get informations for an attribute and pack it into a python dict.

    Parameters
    ----------
    attribute : String
        Name of the attribute. Full Tango name i.e. "test/dg/panda/current".

    db: str
        Which database to look in, 'H' or 'T'.

    Returns
    -------
    info : dict
        Dictionnary of propertyname:propertyvalue
    """
    if not _check_initialized():
        return

    if not db in ("H", "T"):
        raise AttributeError("Attribute db should be 'H' or 'T'")

    extractor = _extractors[{'H':0, 'T':1}[db]]
    info = dict()

    # Merge both property lists from the extractor into one dict.
    for func in ("GetAttDefinitionData", "GetAttPropertiesData"):
        R = getattr(extractor, func)(attribute)
        if R is None:
            logger.warning("Function %s on extractor returned None"%func)
            continue
        for entry in R:
            # Entries come as "name::value" strings.
            parts = entry.split("::")
            info[parts[0]] = parts[1]

    return info

##########################################################################
###                    Module core functions                           ###
##########################################################################

def _extract_attribute(attribute, method, date1, date2, db):
    """
    Check that the attribute exists, determine whether it is a scalar or a
    spectrum, and dispatch to the matching extraction function.
    """

    # Attribute names are matched lowercase against the archiver.
    attribute = attribute.lower()
    _check_attribute(attribute, db)

    # Get info about the attribute
    info = infoattr(attribute, db=db)
    logger.debug("Attribute information \n%s"%info)

    # Determine the attribute shape from its dimensions
    if int(info["max_dim_x"]) <= 1:
        attrtype = "scalar"
    elif int(info["max_dim_y"]) > 0:
        logger.warning("Attribute %s is a (%s; %s) vector. This is poorly handled by this module."%(
            attribute, info["max_dim_x"], info["max_dim_y"]))
        attrtype = "multi"
    else:
        logger.info("Attribute %s is a 1D vector, dimension = %s."%(
            attribute, info["max_dim_x"]))
        attrtype = "vector"

    # 2D ("multi") attributes are handled the same way as scalars for now,
    # which will only get the first element.
    if attrtype == "vector":
        return _extract_vector(attribute, method, date1, date2, db)
    return _extract_scalar(attribute, method, date1, date2, db)


##---------------------------------------------------------------------------##
def _extract_scalar(attribute, method, date1, date2, db):
    """
    Extract a scalar attribute from the selected database.

    Parameters
    ----------
    attribute : str
        Full Tango name of the attribute (lowercase).
    method : str
        'nearest', 'between' or 'minmaxmean' (the latter is not implemented).
    date1, date2 : datetime.datetime
        Extraction boundaries; date2 is ignored for 'nearest'.
    db : str
        Which database to look in, 'H' or 'T'.

    Returns
    -------
    pandas.Series or None
        Extracted data; None when the extractor failed.

    Raises
    ------
    NotImplementedError
        For the 'minmaxmean' method.
    """

    # =====================
    if method == "nearest":
        cmdreturn = _cmd_with_retry(_extractors[{'H':0, 'T':1}[db]], "GetNearestValue", [
                                                attribute,
                                                date1.strftime(_DBDFMT),
                                                ])

        # Unpack return
        try:
            _date, _value = cmdreturn
        except TypeError:
            logger.error("Could not extract this chunk. Check the device extractor")
            return None

        # Fabricate return pandas.Series
        # BUGFIX: used to reference an undefined name '_data' (NameError).
        d=pd.Series(index=[datetime.datetime.fromtimestamp(_date),], data=[_value,], name=attribute)

        return d

    # =====================
    if method == "between":
        # Cut the time horizon in chunks
        # BUGFIX: used to reference undefined names 'dateStart'/'dateStop'.
        cdates = _chunkerize(attribute, date1, date2, db)

        # Array to hold data
        data = []

        # For each date chunk
        for i_d in range(len(cdates)-1):
            cmdreturn = _cmd_with_retry(_extractors[{'H':0, 'T':1}[db]], "ExtractBetweenDates", [
                                                    attribute,
                                                    cdates[i_d].strftime(_DBDFMT),
                                                    cdates[i_d+1].strftime(_DBDFMT)
                                                    ])

            # Unpack return
            try:
                _date, _value = cmdreturn
            except TypeError:
                logger.error("Could not extract this chunk. Check the device extractor")
                return None

            # Transform to datetime - value arrays
            # (timestamps come from the extractor in milliseconds)
            _value = np.asarray(_value, dtype=float)
            if len(_date) > 0:
                _date = _ArrayTimeStampToDatetime(_date/1000.0)

            # Fabricate return pandas.Series
            # BUGFIX: used to reference an undefined name '_data' (NameError).
            data.append(pd.Series(index=_date, data=_value, name=attribute))

        # Concatenate chunks
        return pd.concat(data)

    # ========================
    if method == "minmaxmean":
        # If we are here, the method is not implemented
        # BUGFIX: 'raise NotImplemented' raises a TypeError because
        # NotImplemented is not an exception class.
        logger.error("Method {} is not implemented for scalars.".format(method))
        raise NotImplementedError("Method {} is not implemented for scalars.".format(method))

##---------------------------------------------------------------------------##
def _extract_vector(attribute, method, date1, date2, db):
    """
    Extract a 1D (spectrum) attribute from the selected database.

    Only the 'nearest' method is supported; any other method raises
    NotImplementedError.

    Parameters
    ----------
    attribute : str
        Full Tango name of the attribute (lowercase).
    method : str
        Extraction method; only 'nearest' is implemented.
    date1, date2 : datetime.datetime
        date1 is the target date for 'nearest'; date2 is unused.
    db : str
        Which database to look in, 'H' or 'T'.

    Returns
    -------
    pandas.Series or None
        Single-entry series holding the nearest vector value; None when the
        extractor failed.
    """

    # Get info about the attribute
    info=infoattr(attribute, db=db)

    # =====================
    if method == "nearest":
        # Get nearest does not work with vector.
        # Make a between date with surounding dates.

        # Dynamically grow the window around date1 until it contains a point
        cnt=0
        dt=datetime.timedelta(seconds=10)
        while cnt<1:
            logger.debug("Seeking points in {} to {}".format(date1-dt,date1+dt))
            cnt=_extractors[{'H':0, 'T':1}[db]].GetAttDataBetweenDatesCount([
                    attribute,
                    (date1-dt).strftime(_DBDFMT2),
                    (date1+dt).strftime(_DBDFMT2)
                    ])
            dt=dt*1.5
        logger.debug("Found {} points in a +- {} interval".format(cnt,str(dt/1.5)))

        # For vector, we have to use the GetAttxxx commands
        cmdreturn = _cmd_with_retry(_extractors[{'H':0, 'T':1}[db]], "GetAttDataBetweenDates", [
                                                attribute,
                                                (date1-dt).strftime(_DBDFMT),
                                                (date1+dt).strftime(_DBDFMT),
                                                ])

        # Unpack return: number of points and the dynamic attribute name
        try:
            [N,], [name,] = cmdreturn
            N=int(N)
        except TypeError:
            logger.error("Could not extract this attribute. Check the device extractor")
            return None

        # Read the history
        logger.debug("Retrieve history of %d values. Dynamic attribute named %s."%(N, name))
        attrHist = _extractors[{'H':0, 'T':1}[db]].attribute_history(name, N)

        # Transform to datetime - value arrays; rows shorter than max_dim_x
        # are padded with NaN.
        _value = np.empty((N, int(info["max_dim_x"])), dtype=float)
        _value[:] = np.nan
        _date = np.empty(N, dtype=object)
        for i_h in range(N):
            _value[i_h,:attrHist[i_h].dim_x]=attrHist[i_h].value
            _date[i_h]=attrHist[i_h].time.todatetime()

        # Seeking nearest entry
        idx=np.argmin(abs(_date-date1))
        logger.debug("Found nearest value at index {}: {}".format(idx, _date[idx]))

        # Fabricate return pandas.Series
        d=pd.Series(index=[_date[idx],], data=[_value[idx],], name=attribute)

        return d

    # If we are here, the method is not implemented
    # BUGFIX: 'raise NotImplemented' raises a TypeError because NotImplemented
    # is not an exception class; raise NotImplementedError instead.
    logger.error("Method {} is not implemented for vectors.".format(method))
    raise NotImplementedError("Method {} is not implemented for vectors.".format(method))


##---------------------------------------------------------------------------##
def ExtrBetweenDates_MinMaxMean(
        attribute,
        dateStart,
        dateStop=None,
        timeInterval=datetime.timedelta(seconds=60),
        db='H',
        ):
    """
    Query attribute data from an archiver database, get all points between dates.
    Use ExtractBetweenDates.

    Parameters
    ----------
    attribute : String
        Name of the attribute. Full Tango name i.e. "test/dg/panda/current".

    dateStart : datetime.datetime, string
        Start date for extraction. If string, it will be parsed.
        Example of string format %Y-%m-%d-%H:%M:%S or less precise.

    dateStop : datetime.datetime, string
        Stop date for extraction. If string, it will be parsed.
        Example of string format %Y-%m-%d-%H:%M:%S or less precise.
        Default is now (datetime.datetime.now())

    timeInterval: datetime.timedelta, string
        Time interval used to perform min,max and mean.
        Can be a string with a number and a unit in "d", "h", "m" or "s"

    db: str
        Which database to look in, 'H' or 'T'.

    Exceptions
    ----------
    ValueError
        The attribute is not found in the database.

    Returns
    -------
    result : pandas.DataFrame or None
        DataFrame with columns "Min", "Mean" and "Max", indexed by the
        middle date of each timeInterval window.
        None when the attribute is not a scalar.
    """
    if not _check_initialized():
        return

    if not db in ("H", "T"):
        raise AttributeError("Attribute db should be 'H' or 'T'")

    # Uncapitalize attribute
    attribute = attribute.lower()

    # Check attribute is in database
    _check_attribute(attribute, db=db)

    # Parse dates when they are not already datetime objects.
    # BUGFIX: strings were assumed unconditionally, so passing a datetime
    # crashed, and dateStop=None (the documented default meaning "now")
    # crashed inside _dateparse().
    if not isinstance(dateStart, datetime.datetime):
        dateStart = _dateparse(dateStart)
    if dateStop is None:
        dateStop = datetime.datetime.now()
    elif not isinstance(dateStop, datetime.datetime):
        dateStop = _dateparse(dateStop)

    # Parse timeInterval if string
    if type(timeInterval) is str:
        try:
            mul = {'s':1, 'm':60, 'h':60*60, 'd':60*60*24}[timeInterval[-1]]
        except KeyError:
            logger.error("timeInterval could not be parsed")
            raise ValueError("timeInterval could not be parsed")
        timeInterval= datetime.timedelta(seconds=int(timeInterval[:-1])*mul)

    # Get info about the attribute
    # BUGFIX: the db argument was not forwarded, so 'T' requests still
    # queried the HDB extractor for the attribute information.
    info=infoattr(attribute, db=db)
    logger.debug("Attribute information \n%s"%info)

    # Reject spectrum attributes: only scalars are supported here
    if int(info["max_dim_x"]) > 1:
        logger.error("Attribute is not a scalar. Cannot perform this kind of operation.")
        return None

    # Cut data range in time chunks
    cdates = [dateStart]
    while cdates[-1] < dateStop:
        cdates.append(cdates[-1]+timeInterval)
    cdates[-1] = dateStop
    # Window middles used as the result index
    mdates = np.asarray(cdates[:-1])+timeInterval/2
    logger.debug("Cutting time range to %d chunks of time, %s each."%(len(cdates)-1, timeInterval))

    # Prepare arrays
    value_min = np.empty(len(cdates)-1)
    value_max = np.empty(len(cdates)-1)
    value_mean = np.empty(len(cdates)-1)

    # For each time chunk
    for i_d in range(len(cdates)-1):
        for func, arr in zip(
                ["Max", "Min", "Avg"],
                [value_max, value_min, value_mean],
                ):
            # Make requests
            logger.debug("Perform GetAttData%sBetweenDates (%s, %s, %s)"%(
                func,
                attribute,
                cdates[i_d].strftime(_DBDFMT2),
                cdates[i_d+1].strftime(_DBDFMT2))
                )

            _val =getattr(_extractors[{'H':0, 'T':1}[db]], "GetAttData%sBetweenDates"%func)([
                attribute,
                cdates[i_d].strftime(_DBDFMT2),
                cdates[i_d+1].strftime(_DBDFMT2)
                ])

            arr[i_d] = _val

    logger.debug("Extraction done for %s."%attribute)
    return pd.DataFrame(
            index=mdates,
            data={
                "Min":value_min,
                "Mean":value_mean,
                "Max":value_max,
                },)

## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ##
## Initialize on import
## NOTE(review): importing this module therefore contacts the Tango
## extractor devices as a side effect; the import may fail or block if
## they are unreachable -- confirm this is intended.
## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ## ##
init()