
import logging
from operator import attrgetter, concat
from collections import namedtuple

import numpy as np
import pandas as pd

from .. import archive
from . import ProblemType

log = logging.getLogger(__name__)


class Monitor(object):
    """Abstract mixin that keeps track of a set of parameters of the model or
    of the learning process.

    Concrete monitors combine this class with a ``namedtuple`` base and
    define two class attributes:

    * ``_ftypes`` -- one Python type (``int``, ``float`` or ``str``) per
      namedtuple field; it selects the CSV format (``ftype_frm``) and the
      placeholder value (``fzero``) of that field;
    * ``_report_frm`` -- a ``str.format`` template over the field names,
      rendered by :attr:`report`.

    The ``is_*`` / ``_has_changed`` classmethods work on a *history*: a
    sequence of records with the most recent one last.
    """

    # printf-style format used for each field type in the CSV line
    ftype_frm = {int: "%d", float: "%.6f", str: "%s"}
    # placeholder value for each field type, used by `_empty`
    fzero = {int: 0, float: 0.0, str: "NA"}

    @property
    def csv(self, sep=","):
        """This record as one CSV line, without a trailing newline.

        ``sep`` cannot be supplied through property access, so it is always ",".
        """
        return sep.join( [self.ftype_frm[t] for t in self._ftypes] ) % self

    @classmethod
    def _empty(cls):
        """A record whose fields all hold their type's placeholder (`fzero`)."""
        return tuple.__new__(cls, [cls.fzero[t] for t in cls._ftypes])

    @property
    def report(self):
        """Human-readable rendering of this record via ``_report_frm``."""
        return self._report_frm.format(**self._asdict())

    @classmethod
    def _has_changed(cls, what, patience, data, summ_f, cmp_f):
        """Compare a summary of the last `patience` values of field `what`
        with the value recorded just before them.

        :param what: name of the field to inspect
        :param patience: length of the trailing window; must be >= 1
        :param data: history of records, most recent last
        :param summ_f: reduces the window to a single value (e.g. ``min``)
        :param cmp_f: called as ``cmp_f(summary, mark)``; its result is returned
        :raises ValueError: if `patience` < 1 or the history holds fewer than
            ``patience + 1`` records
        """
        # with patience == 0, data[-0:] is the *whole* history (mark included);
        # negative values select windows that do not precede the mark at all
        if patience < 1:
            raise ValueError("patience must be >= 1, got %d" % patience)
        if len(data) < patience + 1:
            raise ValueError("history too short(%d) for this patience(%d)" % (len(data),
                                                                              patience))

        what_f = attrgetter(what)

        # values of `what` over the last `patience` records; a list so it can be
        # summarized and logged without recomputation (map is lazy on Python 3)
        window = [what_f(rec) for rec in data[-patience:]]
        patience_best = summ_f(window)
        # the value just before the window
        history_mark = what_f(data[-patience - 1])

        assert type(patience_best) == type(history_mark)
        log.debug("%f, %f [%s]", history_mark, patience_best, window)

        return cmp_f(patience_best, history_mark)

    @classmethod
    def rel_diff(cls, init, final):
        """(final-init)/max(final, init); 0.0 when that maximum is 0."""
        M = float( max(final, init) )
        if M:
            return (final - init) / M
        return 0.

    # The is_{min,max}_{up,down,still} family: `memory` is the patience window
    # handed to `_has_changed`; the window is summarized with min or max and
    # compared (>, <, ==) against the value just before it.

    @classmethod
    def is_min_up(cls, what, memory, data):
        return cls._has_changed(what, memory, data, min,
                                lambda summarized, mark: summarized > mark)

    @classmethod
    def is_min_down(cls, what, memory, data):
        return cls._has_changed(what, memory, data, min,
                                lambda summarized, mark: summarized < mark)

    @classmethod
    def is_min_still(cls, what, memory, data):
        return cls._has_changed(what, memory, data, min,
                                lambda summarized, mark: summarized == mark)

    @classmethod
    def is_max_up(cls, what, memory, data):
        return cls._has_changed(what, memory, data, max,
                                lambda summarized, mark: summarized > mark)

    @classmethod
    def is_max_down(cls, what, memory, data):
        return cls._has_changed(what, memory, data, max,
                                lambda summarized, mark: summarized < mark)

    @classmethod
    def is_max_still(cls, what, memory, data):
        return cls._has_changed(what, memory, data, max,
                                lambda summarized, mark: summarized == mark)

    @classmethod
    def _dump(cls, path, data):
        """Write `data` to `path` as CSV: a header of field names followed by
        one line per record (see `csv`). Any existing file is overwritten."""
        with open(path, 'w') as fd:
            log.info("saving %s log to %s ...", cls, fd.name)
            fd.write(",".join( cls._fields ) + "\n")
            # explicit loop: calling `map` only for its side effect writes
            # nothing on Python 3, where map is lazy
            for rec in data:
                fd.write("%s\n" % rec.csv)

    @classmethod
    def cls_name(cls):
        """Bare class name; used as the last component of archive keys."""
        return cls.__name__

    @classmethod
    def _archive_key(cls, path, train_run_name):
        """Key ``<archive.basename(path)>/<train_run_name>/<class name>``."""
        return "%s/%s/%s" % (archive.basename(path), train_run_name,
                             cls.cls_name())

    @classmethod
    def _archive_dump(cls, path, train_run_name, data):
        """Store `data` as a DataFrame (one column per field) in the archive
        derived from `path` (via the project's `archive` helpers)."""
        obj = pd.DataFrame.from_records(data, columns=cls._fields)
        archive.save_object( archive.archname(path),
                             cls._archive_key(path, train_run_name), obj )

    @classmethod
    def _archive_load(cls, path, train_run_name):
        """Load the object stored by `_archive_dump` under the same key."""
        return archive.load_object(archive.archname(path),
                                   cls._archive_key(path, train_run_name))


class DaeLearnMonitor(Monitor,
                      namedtuple("learnmonitor", ("epoch lepoch lrate layer "
                                                  "traincost validcost "
                                                  "pv_traincost pv_validcost"))):
    """Per-epoch learning statistics tracked layer by layer.

    Fields: global epoch, epoch within the current layer, learning rate,
    layer index, mean train/valid cost, and those costs divided by the
    normalizer ``N`` (``pv_*``, reported as "Cost/var").
    """

    _ftypes = (int, int, float, int, float, float, float, float)
    _report_frm = ("Epoch {lepoch} (layer {layer}) / lrate: {lrate:.6f}\n\t"
                   "Cost: {traincost:.6f} / {validcost:.6f}\n\t"
                   "Cost/var: {pv_traincost:.6f} / {pv_validcost:.6f}\n")

    @classmethod
    def _from_fs(cls, ds, mtr, epoch, lepoch, lidx, cost_f, N):
        """Build a record by evaluating `cost_f` on every train and valid batch.

        :param ds: object exposing ``train_batches`` and ``valid_batches``
        :param mtr: object exposing the current learning rate as ``lr``
        :param epoch: global epoch counter
        :param lepoch: epoch counter within the current layer
        :param lidx: index of the layer being trained
        :param cost_f: callable mapping one batch to its cost
        :param N: normalizer for the ``pv_*`` fields
        """
        # lists, not `map`: on Python 3 np.array(map(...)) is a 0-d object
        # array and cannot be averaged
        ct = np.mean([cost_f(batch) for batch in ds.train_batches])
        cv = np.mean([cost_f(batch) for batch in ds.valid_batches])
        # the reconstruction cost is summed over all variables, so dividing
        # by N (the size of the input) gives a per-variable cost
        return cls._make( [epoch, lepoch, mtr.lr, lidx,
                           ct, cv, ct / N, cv / N] )


def _cnnlearn_monitor_from_fs(ds, mtr, epoch, cost_f, ce_f, mcl_f, pt):
    """Evaluate the cost/error functions over all train and valid batches.

    Shared by the ``_from_fs`` constructors of the Cnn/Mlp learn monitors.

    :param ds: object exposing ``train_batches``, ``valid_batches`` and
        ``batch_size``
    :param mtr: object exposing the current learning rate as ``lr``
    :param epoch: epoch counter stored in the record
    :param cost_f: callable batch -> cost; averaged over batches
    :param ce_f: callable batch -> second error measure; averaged over batches
    :param mcl_f: callable batch -> value summed over batches and divided by
        the number of examples (presumably a misclassification count -- TODO
        confirm); only called when `pt` is ``ProblemType.classification``
    :param pt: problem type
    :returns: ``[epoch, lr, train_cost, train_ce, train_mc,
        valid_cost, valid_ce, valid_mc]``; both ``*_mc`` are 0.0 unless `pt`
        is classification
    """
    train_batches = ds.train_batches
    valid_batches = ds.valid_batches

    # number of examples; assumes every batch holds ds.batch_size items
    # (a smaller last batch would be over-counted) -- TODO confirm
    train_s = ds.batch_size * len(train_batches)
    valid_s = ds.batch_size * len(valid_batches)

    # lists, not `map`: on Python 3 np.array(map(...)) is a 0-d object array
    train_cost = np.mean([cost_f(b) for b in train_batches])
    valid_cost = np.mean([cost_f(b) for b in valid_batches])
    train_ce = np.mean([ce_f(b) for b in train_batches])
    valid_ce = np.mean([ce_f(b) for b in valid_batches])

    train_mc, valid_mc = 0.0, 0.0
    if pt == ProblemType.classification:
        train_mc = sum(mcl_f(b) for b in train_batches) / float(train_s)
        valid_mc = sum(mcl_f(b) for b in valid_batches) / float(valid_s)
    return [epoch, mtr.lr, train_cost, train_ce, train_mc, valid_cost, valid_ce, valid_mc]


class CnnLearnMonitorC(Monitor, namedtuple("learnmonitor",
                                           ("epoch lrate traincost trainCE trainMC "
                                            "validcost validCE validMC"))):
    """Per-epoch learning statistics with a misclassification rate.

    For both the train and the validation set the record holds the mean
    cost, the mean of the second error measure (labelled "CE") and the
    misclassification rate (rendered as a percentage in the report).
    """

    # epoch is the only integer field; the remaining seven are floats
    _ftypes = (int,) + (float,) * 7

    _report_frm = "\n\t".join((
        "Epoch {epoch} / lrate: {lrate:.4f}",
        "Cost: {traincost:.8f} / {validcost:.8f} ",
        "CE: {trainCE:.8f} / {validCE:.8f} ",
        "Missclass.: {trainMC:.2%} / {validMC:.2%}",
    ))

    @classmethod
    def _from_fs(cls, ds, mtr, epoch, cost_f, ce_f, mcl_f, pt):
        """Build a record; the shared helper already yields the values in
        this class's field order."""
        values = _cnnlearn_monitor_from_fs(ds, mtr, epoch,
                                           cost_f, ce_f, mcl_f, pt)
        return cls._make(values)


class CnnLearnMonitorR(Monitor, namedtuple("learnmonitor",
                                           ("epoch lrate traincost trainRMS "
                                            "validcost validRMS"))):
    """Per-epoch learning statistics without a misclassification rate.

    ``trainRMS`` / ``validRMS`` hold the batch-averaged values of ``ce_f``
    computed by the shared helper; its misclassification entries are dropped.
    """

    # epoch is the only integer field; the remaining five are floats
    _ftypes = (int,) + (float,) * 5

    _report_frm = "\n\t".join((
        "Epoch {epoch} / lrate: {lrate:.4f}",
        "Cost: {traincost:.8f} / {validcost:.8f} ",
        "RMS: {trainRMS:.8f} / {validRMS:.8f}",
    ))

    @classmethod
    def _from_fs(cls, ds, mtr, epoch, cost_f, ce_f, mcl_f, pt):
        """Build a record from the shared helper, discarding both
        misclassification values."""
        (ep, lrate, tcost, tce, _tmc,
         vcost, vce, _vmc) = _cnnlearn_monitor_from_fs(ds, mtr, epoch, cost_f,
                                                       ce_f, mcl_f, pt)
        return cls._make((ep, lrate, tcost, tce, vcost, vce))


class MlpLearnMonitorC(CnnLearnMonitorC):
    """Same fields, report and constructor as `CnnLearnMonitorC`; as a
    separate class its `cls_name()`, and thus its archive key, differs."""


class MlpLearnMonitorR(CnnLearnMonitorR):
    """Same fields, report and constructor as `CnnLearnMonitorR`; as a
    separate class its `cls_name()`, and thus its archive key, differs."""


class LearnMonitor(Monitor, namedtuple("learnmonitor",
                                       ("epoch lrate traincost trainCE "
                                        "trainMC validcost validCE validMC"))):
    """learning stats

    Per-epoch learning rate, mean cost, mean second error measure (reported
    as "CE/RMS") and misclassification rate, for train and validation sets.
    """

    _ftypes = (int, float,
               float, float,
               float, float,
               float, float)
    _report_frm = ("Epoch {epoch} / lrate: {lrate:.4f}\n\t"
                   "Cost: {traincost:.8f} / {validcost:.8f} \n\t"
                   "CE/RMS: {trainCE:.8f} / {validCE:.8f} \n\t"
                   "Missclass.: {trainMC:.2%} / {validMC:.2%}")

    @classmethod
    def _from_fs(cls, ds, mtr, epoch, cost_f, ce_f, mcl_f, pt):
        """Build a record by evaluating the given functions on every train and
        valid batch of `ds`.

        Delegates to `_cnnlearn_monitor_from_fs` rather than duplicating it;
        that helper returns the values in exactly this class's field order.
        """
        return cls._make(_cnnlearn_monitor_from_fs(ds, mtr, epoch,
                                                   cost_f, ce_f, mcl_f, pt))


class WeightMonitor(Monitor, namedtuple("w_monitor",
                                        ("epoch layer wshp wmin wmean wmedian wsd wmax "
                                         "bshp bmin bmean bmedian bsd bmax"))):
    """Weight statistics for one layer of the network.

    For each of the two arrays returned by the layer's ``get_weights()``
    (``w*`` then ``b*`` fields): its element count, stored as a string in
    the ``*shp`` field, and its min / mean / median / std / max.
    """

    _ftypes = (int, int,
               str, float, float, float, float, float,
               str, float, float, float, float, float)
    _report_frm = ("{epoch:d} {layer:d}\n\t"
                   "W({wshp:s}): {wmin:.6f} {wmean:.6f} {wmedian:.6f} {wsd:.6f} {wmax:.6f}\n\t"
                   "B({bshp:s}): {bmin:.6f} {bmean:.6f} {bmedian:.6f} {bsd:.6f} {bmax:.6f}")

    @classmethod
    def _from_model(cls, epoch, l, model):
        """Build a record from the arrays returned by ``model[l].get_weights()``.

        :param epoch: epoch counter stored in the record
        :param l: layer index into `model`
        :param model: indexable by layer; ``model[l].get_weights()`` must yield
            exactly two arrays to fill the fields (presumably ``[W, b]`` --
            TODO confirm)
        """
        stat_fs = (np.min, np.mean, np.median, np.std, np.max)

        values = [epoch, l]
        for arr in model[l].get_weights():
            # element count as a string, then the five summary statistics
            values.append(str(np.prod(arr.shape)))
            values.extend([f(arr) for f in stat_fs])
        return cls._make(values)


if __name__ == "__main__":
    # Self-check of every concrete monitor:
    # - `csv` and `_empty` build one value per `_ftypes` entry, so there must
    #   be exactly one entry per namedtuple field;
    # - rendering an empty record exercises the `csv` and `_report_frm`
    #   templates (a bad field name or format spec raises here).
    # Explicit raises instead of `assert` so the check also runs under -O.
    frm_str = "%s: len(_ftypes) = %d != len(_fields) = %d"
    for C in (DaeLearnMonitor, CnnLearnMonitorC, CnnLearnMonitorR,
              MlpLearnMonitorC, MlpLearnMonitorR, LearnMonitor, WeightMonitor):
        if len(C._ftypes) != len(C._fields):
            raise AssertionError(frm_str % (C.__name__,
                                            len(C._ftypes),
                                            len(C._fields)))
        empty = C._empty()
        empty.csv
        empty.report
