"""
Python module for extracting attributes from an Archive Extractor Device.
"""
import logging
import datetime
import numpy as np
import PyTango as tango
import pandas as pd

__version__ = "1.0.1"

##########################################################################
""" Commodity variables """

# Extractor date format for GetAttDataBetweenDates
DBDFMT = "%Y-%m-%d %H:%M:%S"

# Extractor date format for GetNearestValue
DBDFMT2 = "%d-%m-%Y %H:%M:%S"

# Vectorized fromtimestamp function
ArrayTimeStampToDatetime = np.vectorize(datetime.datetime.fromtimestamp)


class ArchiveExtractor:

    ##########################################################################
    def __init__(
            self,
            extractorKind='H', extractorNumber=2,
            extractorPath=None,
            logger='info',
            ):
        """
        Constructor function

        Parameters
        ----------
        extractorKind: char
            Either 'H' or 'T' for HDB or TDB.

        extractorNumber: int
            Number of the archive extractor instance to use.

        extractorPath: string
            Tango path to the extractor.
            If this argument is given, it takes precedence over extractorKind and extractorNumber.

        logger: logging.Logger, str
            Logger object to use.
            If string, can be a log level. A basic logger with a stream handler will be instantiated.
            Defaults to 'info'.

        Return
        ------
        ArchiveExtractor
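
        Examples
        --------
        A minimal construction sketch; the device path shown is the one built
        from extractorKind and extractorNumber, and assumes a reachable
        extractor on the Tango database.

        >>> ae = ArchiveExtractor('H', 2, logger='debug')
        >>> ae = ArchiveExtractor(extractorPath="archiving/HDBExtractor/2")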
        """

        #######################################################
        # Get logger
        if isinstance(logger, logging.Logger):
            self.logger = logger
        else:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(getattr(logging, logger.upper()))
            if not self.logger.hasHandlers():
                # No handlers, create one
                sh = logging.StreamHandler()
                sh.setLevel(self.logger.level)
                sh.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))
                self.logger.addHandler(sh)

        #######################################################
        # Select Extractor
        if extractorPath is None:
            self.extractor = tango.DeviceProxy(
                    "archiving/%sDBExtractor/%d"%(extractorKind, extractorNumber)
                    )
        else:
            self.extractor = tango.DeviceProxy(extractorPath)

        self.extractor.set_timeout_millis(3000)
        self.logger.debug("Archive Extractor %s used."%self.extractor.name())

        # Get the attribute table
        self.attr_table = self.extractor.getattnameall()

    ##---------------------------------------------------------------------------##
    @staticmethod
    def dateparse(datestr):
        """
        Convenience function to parse date strings.
        The full format is %Y-%m-%d-%H:%M:%S; trailing fields may be omitted for a less precise date.

        Parameters
        ---------
        datestr : string
            Date as a string, format %Y-%m-%d-%H:%M:%S or less precise.

        Exceptions
        ----------
        ValueError
            If the parsing failed.

        Returns
        -------
        date : datetime.datetime
            Parsed date
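
        Examples
        --------
        A few illustrative parses; less precise strings fall back to the
        shorter formats.

        >>> ArchiveExtractor.dateparse("2021-02-03-14:30:00")
        datetime.datetime(2021, 2, 3, 14, 30)
        >>> ArchiveExtractor.dateparse("2021-02-03")
        datetime.datetime(2021, 2, 3, 0, 0)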
        """

        # All formats that will be tried, in order.
        # Stop on the first successful parse; raise an error if none succeeds.
        fmt = [
            "%Y-%m-%d-%H:%M:%S",
            "%Y-%m-%d-%H:%M",
            "%Y-%m-%d-%H",
            "%Y-%m-%d",
            "%Y-%m",
            ]

        date = None
        for f in fmt:
            try:
                date = datetime.datetime.strptime(datestr, f)
            except ValueError:
                continue
            else:
                break
        else:
            raise ValueError("Could not parse argument to a date")

        return date

    ##---------------------------------------------------------------------------##
    def betweenDates(
            self,
            attribute,
            dateStart,
            dateStop=None,
            ):
        """
        Query attribute data from an archiver database, get all points between dates.
        Use ExtractBetweenDates.

        Parameters
        ----------
        attribute : String
            Name of the attribute. Full Tango name, e.g. "test/dg/panda/current".

        dateStart : datetime.datetime, string
            Start date for extraction. If string, it will be parsed.
            Example of string format %Y-%m-%d-%H:%M:%S or less precise.

        dateStop : datetime.datetime, string, None
            Stop date for extraction.
            If string, it will be parsed.
            Example of string format %Y-%m-%d-%H:%M:%S or less precise.
            If None, it takes the current date and time.
            Default is None (now).

        Exceptions
        ----------
        ValueError
            The attribute is not found in the database.

        Returns
        -------
        data : pandas.Series or pandas.DataFrame
            For a scalar attribute, a pandas Series of archived values indexed
            by date.
            For a 1D (spectrum) attribute, a pandas DataFrame indexed by date,
            with one column per element; columns holding only NaN are dropped.

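        Examples
        --------
        A usage sketch; `ae` is assumed to be a connected ArchiveExtractor and
        the attribute name is illustrative.

        >>> ae = ArchiveExtractor('H', 2)
        >>> current = ae.betweenDates(
        ...     "test/dg/panda/current", "2021-02-03", "2021-02-03-12:00:00")
        >>> maximum = current.max()  # pandas Series for a scalar attribute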
        """

        # Parse date if it is string
        if type(dateStart) is str:
            dateStart = self.dateparse(dateStart)
        if dateStop is None:
            dateStop = datetime.datetime.now()
        if type(dateStop) is str:
            dateStop = self.dateparse(dateStop)

        # Lowercase the attribute name
        attribute = attribute.lower()

        # Check attribute is in database
        self._check_attribute(attribute)

        # Get info about the attribute
        info=self.infoattr(attribute)
        self.logger.debug("Attribute information \n%s"%info)

        # Detect spectrum
        attrtype="scalar"
        if int(info["max_dim_x"]) > 1:
            if int(info["max_dim_y"]) > 0:
                self.logger.warning("Attribute %s is a (%s; %s) vector. This is poorly handled by this script."%(
                    attribute, info["max_dim_x"], info["max_dim_y"]))
                attrtype="multi"
            else:
                self.logger.info("Attribute %s is a 1D vector, dimension = %s."%(
                    attribute, info["max_dim_x"]))
                attrtype="vector"

        # Cut the time horizon in chunks
        cdates = self.chunkerize(attribute, dateStart, dateStop)

        # Arrays to hold every chunks
        value = []
        date = []

        # For each date chunk
        for i_d in range(len(cdates)-1):

            # =============
            # For now, multi-dimensional attributes are handled the same way as scalars, which only retrieves the first element
            if (attrtype=="scalar") or (attrtype=="multi"):
                # Inform on retrieval request
                self.logger.info("Perform ExtractBetweenDates (%s, %s, %s)"%(
                    attribute,
                    cdates[i_d].strftime(DBDFMT),
                    cdates[i_d+1].strftime(DBDFMT))
                    )

                cmdreturn = self._cmd_with_retry("ExtractBetweenDates", [
                                                        attribute,
                                                        cdates[i_d].strftime(DBDFMT),
                                                        cdates[i_d+1].strftime(DBDFMT)
                                                        ])

                # Check command return
                if cmdreturn is None:
                    logger.error("Could not extract this chunk. Check the device extractor")
                    return None

                # Unpack return
                _date, _value = cmdreturn

                # Transform to datetime - value arrays
                # NOTE: it is faster than using pandas.to_datetime()
                _value = np.asarray(_value, dtype=float)
                if len(_date) > 0:
                    _date = ArrayTimeStampToDatetime(_date/1000.0)

                value.append(_value)
                date.append(_date)

            # =============
            if attrtype=="vector":
                self.logger.info("Perform GetAttDataBetweenDates (%s, %s, %s)"%(
                                                        attribute,
                                                        cdates[i_d].strftime(DBDFMT),
                                                        cdates[i_d+1].strftime(DBDFMT)
                                                        ))

                [N,], [name,] = self.extractor.GetAttDataBetweenDates([
                    attribute,
                    cdates[i_d].strftime(DBDFMT),
                    cdates[i_d+1].strftime(DBDFMT)
                    ])
                N=int(N)

                # Read the history
                self.logger.debug("Retrieve history of %d values. Dynamic attribute named %s."%(N, name))
                attrHist = self.extractor.attribute_history(name, N)

                # Transform to datetime - value arrays
                _value = np.empty((N, int(info["max_dim_x"])), dtype=float)
                _value[:] = np.nan
                _date = np.empty(N, dtype=object)
                for i_h in range(N):
                    _value[i_h,:attrHist[i_h].dim_x]=attrHist[i_h].value
                    _date[i_h]=attrHist[i_h].time.todatetime()

                # Remove dynamic attribute
                self.logger.debug("Remove dynamic attribute %s."%name)
                self.extractor.RemoveDynamicAttribute(name)


                value.append(_value)
                date.append(_date)

        self.logger.debug("Concatenate chunks")
        value = np.concatenate(value)
        date = np.concatenate(date)

        self.logger.debug("Extraction done for %s."%attribute)
        if attrtype=="vector":
            return pd.DataFrame(index=date, data=value).dropna(axis=1, how='all')
        else:
            return pd.Series(index=date, data=value)


    ##---------------------------------------------------------------------------##
    def betweenDates_MinMaxMean(
            self,
            attribute,
            dateStart,
            dateStop=None,
            timeInterval=datetime.timedelta(seconds=60),
            ):
        """
        Query attribute data from an archiver database; compute the minimum,
        maximum and mean over each time interval between dates.
        Use GetAttDataMinBetweenDates, GetAttDataMaxBetweenDates and
        GetAttDataAvgBetweenDates.

        Parameters
        ----------
        attribute : String
            Name of the attribute. Full Tango name, e.g. "test/dg/panda/current".

        dateStart : datetime.datetime, string
            Start date for extraction. If string, it will be parsed.
            Example of string format %Y-%m-%d-%H:%M:%S or less precise.

        dateStop : datetime.datetime, string, None
            Stop date for extraction. If string, it will be parsed.
            Example of string format %Y-%m-%d-%H:%M:%S or less precise.
            If None, it takes the current date and time.
            Default is None (now).

        timeInterval: datetime.timedelta, string
            Time interval over which min, max and mean are computed.
            Can be a string made of a number and a unit among "d", "h", "m"
            or "s", e.g. "1h".

        Exceptions
        ----------
        ValueError
            The attribute is not found in the database.

        Returns
        -------
        [mdates, value_min, value_max, value_mean] : array
            mdates : numpy.ndarray of datetime.datetime objects
                Dates of the values, middle of timeInterval windows
            value_min : numpy.ndarray
                Minimum of the value on the interval
            value_max : numpy.ndarray
                Maximum of the value on the interval
            value_mean : numpy.ndarray
                Mean of the value on the interval

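        Examples
        --------
        A usage sketch; `ae` is assumed to be a connected ArchiveExtractor and
        the attribute name is illustrative.

        >>> mdates, vmin, vmax, vmean = ae.betweenDates_MinMaxMean(
        ...     "test/dg/panda/current", "2021-02-03", "2021-02-04",
        ...     timeInterval="1h")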
        """

        # Parse date if it is string
        if type(dateStart) is str:
            dateStart = self.dateparse(dateStart)
        if dateStop is None:
            dateStop = datetime.datetime.now()
        if type(dateStop) is str:
            dateStop = self.dateparse(dateStop)

        # Parse timeInterval if string
        if type(timeInterval) is str:
            try:
                mul = {'s':1, 'm':60, 'h':60*60, 'd':60*60*24}[timeInterval[-1]]
            except KeyError:
                self.logger.error("timeInterval could not be parsed")
                raise ValueError("timeInterval could not be parsed")
            timeInterval= datetime.timedelta(seconds=int(timeInterval[:-1])*mul)


        # Check that the attribute is in the database
        self._check_attribute(attribute)

        # Get info about the attribute
        info=self.infoattr(attribute)
        self.logger.debug("Attribute information \n%s"%info)

        # Detect spectrum
        attrtype="scalar"
        if int(info["max_dim_x"]) > 1:
            self.logger.error("Attribute is not a scalar. Cannot perform this kind of operation.")
            return None

        # Cut data range in time chunks
        cdates = [dateStart]
        while cdates[-1] < dateStop:
            cdates.append(cdates[-1]+timeInterval)
        cdates[-1] = dateStop
        mdates = np.asarray(cdates[:-1])+timeInterval/2
        self.logger.debug("Cutting time range to %d chunks of time, %s each."%(len(cdates)-1, timeInterval))

        # Prepare arrays
        value_min = np.empty(len(cdates)-1)
        value_max = np.empty(len(cdates)-1)
        value_mean = np.empty(len(cdates)-1)

        # For each time chunk
        for i_d in range(len(cdates)-1):
            for func, arr in zip(
                    ["Max", "Min", "Avg"],
                    [value_max, value_min, value_mean],
                    ):
                # Make requests
                self.logger.debug("Perform GetAttData%sBetweenDates (%s, %s, %s)"%(
                    func,
                    attribute,
                    cdates[i_d].strftime(DBDFMT2),
                    cdates[i_d+1].strftime(DBDFMT2))
                    )

                _val =getattr(self.extractor, "GetAttData%sBetweenDates"%func)([
                    attribute,
                    cdates[i_d].strftime(DBDFMT2),
                    cdates[i_d+1].strftime(DBDFMT2)
                    ])

                arr[i_d] = _val

        self.logger.debug("Extraction done for %s."%attribute)
        return [mdates, value_min, value_max, value_mean]

    def _check_attribute(self, attribute):
        """
        Check that the attribute is in the database

        Parameters
        ----------
        attribute : String
            Name of the attribute. Full Tango name, e.g. "test/dg/panda/current".
        """
        self.logger.debug("Check that %s is archived."%attribute)
        if not self.extractor.IsArchived(attribute):
            self.logger.error("Attribute '%s' is not archived in DB %s"%(attribute, self.extractor))
            raise ValueError("Attribute '%s' is not archived in DB %s"%(attribute, self.extractor))

    def _cmd_with_retry(self, cmd, arg, retry=2):
        """
        Run a command on extractor tango proxy, retrying on DevFailed.

        Parameters
        ----------
        cmd : str
            Command to execute on the extractor

        arg : list
            Argument list to pass to the command

        retry : int
            Number of retries of the command on DevFailed

        Returns
        -------
        cmdreturn :
            Whatever the command returns.
            None if failed after the amount of retries.
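
        Examples
        --------
        A sketch of how betweenDates uses this helper; the attribute name and
        dates are illustrative.

        >>> ret = ae._cmd_with_retry("ExtractBetweenDates", [
        ...     "test/dg/panda/current",
        ...     "2021-02-03 00:00:00",
        ...     "2021-02-04 00:00:00"])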
        """

        for i in range(retry):
            # Make retrieval request
            self.logger.debug("Execute %s (%s)"%(cmd, arg))
            try:
                cmdreturn = getattr(self.extractor, cmd)(arg)
            except tango.DevFailed as e:
                self.logger.warning("The extractor device returned the following error:")
                self.logger.warning(e)
                if i == retry-1:
                    self.logger.error("Could not execute command %s (%s). Check the device extractor"%(cmd, arg))
                    return None
                self.logger.warning("Retrying...")
                continue
            break
        return cmdreturn


    def chunkerize(self, attribute, dateStart, dateStop, Nmax=100000):
        """
        Cut the extraction time range into chunks holding at most Nmax data points each.

        Parameters
        ----------
        attribute : String
            Name of the attribute. Full Tango name, e.g. "test/dg/panda/current".

        dateStart : datetime.datetime
            Start date for extraction.

        dateStop : datetime.datetime
            Stop date for extraction.

        Nmax : int
            Maximum number of data points per chunk. Default is 100000.

        Returns
        -------
        cdates : list
            List of datetimes giving the limits of each chunk.
            For N chunks, cdates holds N+1 elements, since the start and stop
            boundaries are both included.
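
        Examples
        --------
        A usage sketch; dates and attribute name are illustrative and `ae` is
        assumed to be a connected ArchiveExtractor.

        >>> start = ArchiveExtractor.dateparse("2021-02-03")
        >>> stop = ArchiveExtractor.dateparse("2021-02-04")
        >>> cdates = ae.chunkerize("test/dg/panda/current", start, stop)
        >>> nchunks = len(cdates) - 1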
        """
        info=self.infoattr(attribute)
        self.logger.debug("Attribute information \n%s"%info)

        # Get the number of points
        N=self.extractor.GetAttDataBetweenDatesCount([
                attribute,
                dateStart.strftime(DBDFMT2),
                dateStop.strftime(DBDFMT2)
                ])
        self.logger.debug("On the period, there is %d entries"%N)

        dx=int(info["max_dim_x"])
        if dx > 1:
            self.logger.debug("Attribute is a vector with max dimension = %s"%dx)
            N=N*dx

        # If data chunk is too much, we need to cut it
        if N > Nmax:
            dt = (dateStop-dateStart)/(N//Nmax)
            cdates = [dateStart]
            while cdates[-1] < dateStop:
                cdates.append(cdates[-1]+dt)
            cdates[-1] = dateStop
            self.logger.debug("Cutting access to %d little chunks of time, %s each."%(len(cdates)-1, dt))
        else:
            cdates=[dateStart, dateStop]

        return cdates

    def findattr(self, pattern):
        """
        Search for an attribute path using the pattern given.
        Case insensitive.

        Parameters
        ----------
        pattern : str
            Pattern to search for; the wildcard * is accepted,
            e.g. "dg*dcct*current".

        Returns
        -------
        matches : list of str
            List of matching attribute names.
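
        Examples
        --------
        A usage sketch reusing the pattern above; the actual matches depend on
        what is archived in the database.

        >>> matches = ae.findattr("dg*dcct*current")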
        """
        keywords=pattern.lower().split('*')

        matches = [attr for attr in self.attr_table if all(k in attr.lower() for k in keywords)]

        return matches


    def infoattr(self, attribute):
        """
        Get information about an attribute and pack it into a Python dict.

        Parameters
        ----------
        attribute : String
            Name of the attribute. Full Tango name, e.g. "test/dg/panda/current".

        Returns
        -------
        info : dict
            Dictionary of property name to property value.
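
        Examples
        --------
        A usage sketch; the attribute name is illustrative.

        >>> info = ae.infoattr("test/dg/panda/current")
        >>> dx = int(info["max_dim_x"])  # 1 for a scalar attribute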
        """
        info = dict()

        for func in ("GetAttDefinitionData", "GetAttPropertiesData"):
            R=getattr(self.extractor, func)(attribute)
            if R is not None:
                for i in R:
                    _s=i.split("::")
                    info[_s[0]]=_s[1]
            else:
                self.logger.warning("Function %s on extractor returned None"%func)


        return info