'''
A model template is based on a deep NN model. For each layer i, it
maintains or computes

 * decision_boundary(i): the representation of the
   training data at the output of layer i.

 * constraint_boundary(i): the representation
   of the training data with label self.label
   at the output of layer i.

 * feature_templ(i): a template input that layer i
   would map to a signal satisfying constraint_boundary(i).
'''

import numpy as np
import pandas as pd
import logging
import gurobipy as grb

from dimer import archive
from . import unpool

# module-level logger, named after this module
log = logging.getLogger(name=__name__)


class BaseMTmpl(object):
    """Basic abstract class for a model template.

    Keeps references to the model, the dataset and the per-layer inversion
    settings, and memoizes the per-layer inversion objects in
    ``self._grbm`` (a dict keyed by layer index).

    Subclasses are expected to set ``self.label`` and to implement
    ``features(i)``; ``cb``, ``dump``, ``from_archive`` and
    ``feature_stats`` rely on them.
    """

    def __init__(self, ds, model, bounds, inv_cls, inv_act, objdirs, slack):
        """
        :param ds: dataset exposing inputs ``ds.X`` and labels ``ds.T``
        :param model: layered model; ``len(model)`` is the number of layers
        :param bounds: per-layer variable bounds
        :param inv_cls: per-layer classes used to invert the layers
        :param inv_act: per-layer inverse activation functions
        :param objdirs: per-layer objective directions
        :param slack: amount subtracted from the label coordinate and added
            to every other coordinate of the top-layer constraint boundary
            (see ``cb``)
        """
        # every per-layer setting must provide exactly one entry per layer
        assert len(model) == len(bounds) == len(inv_cls) == \
            len(inv_act) == len(objdirs)

        self.ds = ds
        self.model = model
        self.bounds = bounds
        self.inv_cls = inv_cls
        self.inv_act = inv_act
        self.objdirs = objdirs
        self.slack = slack
        self._grbm = {}  # layer index -> inversion object of that layer

    def from_archive(self, path, trid, li):
        """Load the template of layer ``li`` from an archive.

        ``dump`` stores a template as a flat ``pd.Series`` whose name encodes
        the original shape as underscore-separated integers; this reverses it.

        :param path: dset_path
        :param trid: trainid
        :param int li: layer index
        :rtype: ndarray reshaped to the stored shape
        """
        log.info("loading template (%d) ... ", li)
        f, k = self.__path(path, trid, li, self.label)
        obj = archive.load_object(f, k)
        # materialize the shape: ndarray.reshape rejects a lazy ``map``
        # iterator on Python 3
        shape = tuple(int(d) for d in obj.name.split("_"))
        return obj.values.reshape(shape)

    @staticmethod
    def __path(path, trid, layer, label):
        """
        :param path: dset_path
        :param trid: trainid
        :param layer: the layer index (0, 1, ...)
        :param label: the label the template was computed for
        :rtype: (file path, path_inside_archive) to the template of this layer
        """

        return (archive.archname(path),
                "/".join((archive.basename(path),
                          trid,
                          "layer_%d" % layer,
                          "template_%s" % str(label))))

    def dump(self, path, trid, i=None):
        """Dump templates to a model archive.

        Each template is flattened into a ``pd.Series`` named after its
        shape (e.g. ``"1_28_28"``) so that ``from_archive`` can restore it.

        :param path: dset_path
        :param trid: trainid
        :param i: layer index to dump; ``None`` dumps every layer of the model
        """
        ilst = range(len(self.model)) if i is None else [i]

        for li in ilst:
            log.info("save template (%d) ... ", li)
            f, k = self.__path(path, trid, li, self.label)
            x = self.features(li)
            archive.save_object(f, k, pd.Series(x.flatten(),
                                name="%s" % "_".join(map(str, x.shape))))

    def cb(self, i):
        """Constraint boundary at the output of layer ``i``.

        This is the state of layer ``i`` for the mean of the training inputs
        labelled ``self.label``. At the top layer the slack is applied: the
        ``self.label`` coordinate is decreased and every other coordinate is
        increased by ``self.slack``.

        ``i == -1`` denotes the input space and returns the class mean itself,
        mirroring ``db(-1)`` and ``_get_i_params_(0)``.
        """
        class_mean = self.ds.X[self.ds.T == self.label].mean(axis=0)
        if i == -1:  # input space: nothing to propagate through
            return class_mean

        x = self.model.compute_state(i, class_mean)
        if i == len(self.model) - 1:  # top layer, apply slack
            for j in range(x.shape[1]):
                if j == self.label:
                    x[0, j] -= self.slack
                else:
                    x[0, j] += self.slack
        return x

    def db(self, i):
        """Decision boundary at the output of layer ``i``.

        Computed as the state of layer ``i`` for the mean of all training
        inputs. ``i == -1`` denotes the input space and returns that mean.
        """
        if i == -1:  # mean of the data
            return self.ds.X.mean(axis=0)
        return self.model.compute_state(i, self.ds.X.mean(axis=0))

    def _get_feats_(self, i):
        """Target output for the inversion of layer ``i``.

        At the top layer this is the slacked constraint boundary. Otherwise
        it is the features of layer ``i + 1``, which are the output of
        layer ``i``. Either way the result is reshaped to ``model.osh[i]``.
        """
        osh = self.model.osh[i]
        if i == len(self.model) - 1:  # feats = slacked constraint boundary
            return self.cb(i).reshape(osh)
        else:  # feats = previous features
            return self.features(i + 1).reshape(osh)

    def _get_i_params_(self, i):
        """Return ``(icb, idb)``: constraint / decision boundary at the input
        of layer ``i``, reshaped to ``model.ish[i]``."""
        ish = self.model.ish[i]
        if i == 0:
            icb = np.mean(self.ds.X[self.ds.T == self.label], axis=0).reshape(ish)
            idb = np.mean(self.ds.X, axis=0).reshape(ish)
        else:
            icb = self.cb(i - 1).reshape(ish)
            idb = self.db(i - 1).reshape(ish)
        return (icb, idb)

    def x_cb_cost(self, i):
        """Delegate to the inversion object of layer ``i``.

        Raises KeyError unless ``features(i)`` has already been computed.
        """
        return self._grbm[i].x_cb_cost()

    def feature_stats(self, i):
        """Compute descriptive statistics on the extracted features.

        Fields of the returned ``FeatureStats``:

        * ``layer``, ``label``: the layer index and ``self.label``
        * ``flat_dim``: number of elements of the flattened input of layer i
        * ``real_dim``: ``model.ish[i]``
        * ``feat_norm`` / ``sig_norm``: L1 distance of the features / of the
          input constraint boundary from the input decision boundary
        * ``diff_min`` / ``diff_max`` / ``diff_mean``: statistics of
          ``|x - db| - |cb - db|`` (elementwise)
        * ``feat_prob`` / ``sig_prob``: the ``self.label`` component after
          propagating the features / the constraint boundary through layers
          ``i .. top``

        :param int i: layer index
        :rtype: a namedtuple with statistics"""
        from collections import namedtuple

        db = self.db(i - 1).flatten()
        x = self.features(i).flatten()
        cb = self.cb(i - 1).flatten()
        assert x.shape == cb.shape == db.shape
        diff = np.abs(x - db) - np.abs(cb - db)

        fstats = namedtuple("FeatureStats",
                            ("layer label flat_dim real_dim "
                             "feat_norm sig_norm diff_min diff_max diff_mean "
                             "feat_prob sig_prob"))
        px, pcb = np.copy(x), np.copy(cb)
        for j in range(i, len(self.model)):
            px = self.model.compute_repr(j, px)
            pcb = self.model.compute_repr(j, pcb)

        return fstats._make((i, self.label, db.shape[0], self.model.ish[i],
                             np.abs(x - db).sum(), np.abs(cb - db).sum(),
                             np.min(diff), np.max(diff), np.mean(diff),
                             px.flatten()[self.label], pcb.flatten()[self.label]))


class CnnMTempl(BaseMTmpl):
    """Automate feature extraction for a whole CNN model (stack of layers).

    There is a fair amount of settings. All per-layer settings below are
    sequences indexed by layer.

    :param cnnm: Cnn model
    :param ds: dataset
    :param label: the label with respect to which we are solving
    :param bounds: variable bounds -- these define the domain of variables
    :param inv_cls: class to invert the layers. These should
        correspond to the layer type (see gmodel)
    :param inv_act: inverse of the activation function
    :param objdirs: minimize/maximize
    :param ignore_eqcond: whether to ignore (i.e., not set)
        conditions for which there is an equality sign.
    :param ignore_eqvar: whether to ignore (i.e., not set)
        variable bounds for vars equal to decision boundary
    :param weight_objvars: per-layer setting forwarded unchanged to the
        inverter class (its meaning is defined by ``inv_cls``)
    :param feat_ocb: set the constraint boundary = to the feature template
        of previous layer
    :param respect_icb: set variable bounds as: |db - x| < |db - cb|,
        where db and cb refer to previous layer
    :param ce_slack: achieve this much less CE in output than the training data achieves
    """

    def __init__(self, cnnm, ds, label,
                 bounds, inv_cls, inv_act, objdirs,
                 ignore_eqcond, ignore_eqvar, weight_objvars, feat_ocb, respect_icb,
                 ce_slack):

        BaseMTmpl.__init__(self, ds, cnnm, bounds, inv_cls,
                           inv_act, objdirs, ce_slack)
        self.ignore_eqcond = ignore_eqcond
        self.ignore_eqvar = ignore_eqvar
        self.weight_objvars = weight_objvars
        self.feats_ocb = feat_ocb  # stored under the plural name ``feats_ocb``
        self.respect_icb = respect_icb
        self.label = label

    def _make_inverter(self, i, idb, odb, icb, ocb, feats):
        """Instantiate (without solving) the inversion object for layer ``i``.

        The decision and constraint boundaries are passed as (input, output)
        pairs, followed by the target features and this layer's settings.
        """
        return self.inv_cls[i](self.model[i],
                               (idb, odb), (icb, ocb),
                               feats,
                               self.bounds[i],
                               self.inv_act[i],
                               self.objdirs[i],
                               self.ignore_eqcond[i],
                               self.ignore_eqvar[i],
                               self.weight_objvars[i],
                               self.feats_ocb[i],
                               self.respect_icb[i],
                               )

    def _old_features(self, i):
        """Legacy variant of ``features``.

        Unlike ``features``, for layers up to ``model.top_cp_idx`` it unpools
        the output-side boundaries and the target features *before* building
        the inverter, and it solves through the inverter's ``solve()``.
        """
        ish, osh = self.model.ish[i], self.model.osh[i]

        if i not in self._grbm:
            feats = self._get_feats_(i)
            ocb = self.cb(i).reshape(osh)
            odb = self.db(i).reshape(osh)
            icb, idb = self._get_i_params_(i)
            log.info("solving layer %d", i)

            if i <= self.model.top_cp_idx:
                psh = self.model[i].pshape
                ocb = unpool(ocb, psh)
                odb = unpool(odb, psh)
                feats = unpool(feats, psh)
            gm = self._make_inverter(i, idb, odb, icb, ocb, feats)
            gm.solve()
            self._grbm[i] = gm
        return self._grbm[i].get_sol().reshape(ish)

    def features(self, i):
        """Get the input-features for this layer.

        On first use the layer is inverted (which recursively computes the
        features of layer ``i + 1`` through ``_get_feats_``) and the inverter
        is memoized in ``self._grbm``. For layers up to ``model.top_cp_idx``
        the solution is unpooled with the layer's ``pshape``.

        If retrieving the solution raises ``GurobiError``, the error is logged
        and the inverter's ``cbi`` attribute is returned instead.

        :param int i: index of layer
        :rtype: ndarray of extracted features"""

        ish, osh = self.model.ish[i], self.model.osh[i]

        if i not in self._grbm:
            feats = self._get_feats_(i)
            ocb = self.cb(i).reshape(osh)
            odb = self.db(i).reshape(osh)
            icb, idb = self._get_i_params_(i)
            log.info("solving layer %d", i)

            # stored before optimizing: if optimize() raises, later calls
            # skip straight to get_sol() below instead of re-solving
            self._grbm[i] = self._make_inverter(i, idb, odb, icb, ocb, feats)
            self._grbm[i].model.optimize()
        try:
            solution = self._grbm[i].get_sol().reshape(ish)
            if i <= self.model.top_cp_idx:
                # presumably the conv/pool part of the network -- confirm
                solution = unpool(solution, self.model[i].pshape)
        except grb.GurobiError as verr:
            log.error(verr.message)
            return self._grbm[i].cbi
        return solution


class MlpMTempl(BaseMTmpl):
    """Automate feature extraction for a whole MLP model (stack of layers).

    :param mlpm: the MLP model (stored as ``self.model``)
    :param ds: dataset
    :param label: the label with respect to which we are solving
    :param bounds, inv_cls, inv_act, objdirs: per-layer inversion settings
        handed to ``BaseMTmpl``
    :param ignore_eqcond, ignore_eqvar, weight_objvars, feat_ocb, respect_icb:
        per-layer settings forwarded to the inverter of each layer
    :param ce_slack: slack applied to the top-layer constraint boundary
    """

    def __init__(self, mlpm, ds, label,
                 bounds, inv_cls, inv_act, objdirs,
                 ignore_eqcond, ignore_eqvar, weight_objvars,
                 feat_ocb, respect_icb, ce_slack):

        BaseMTmpl.__init__(self, ds=ds, model=mlpm, bounds=bounds,
                           inv_cls=inv_cls, inv_act=inv_act,
                           objdirs=objdirs, slack=ce_slack)

        self.label = label
        # per-layer settings, indexed by layer in ``features``
        self.ignore_eqcond = ignore_eqcond
        self.ignore_eqvar = ignore_eqvar
        self.weight_objvars = weight_objvars
        self.feats_ocb = feat_ocb
        self.respect_icb = respect_icb

    def features(self, i):
        """Return the input-features of layer ``i``.

        The first call for a layer builds its inverter, solves it and caches
        it in ``self._grbm``; later calls reuse the cached solution.

        :param int i: index of layer
        :rtype: ndarray of extracted features, shaped ``model.ish[i]``"""

        in_shape = self.model.ish[i]
        out_shape = self.model.osh[i]

        if i in self._grbm:
            return self._grbm[i].get_sol().reshape(in_shape)

        target = self._get_feats_(i)
        out_cb = self.cb(i).reshape(out_shape)
        out_db = self.db(i).reshape(out_shape)
        in_cb, in_db = self._get_i_params_(i)
        log.info("solving layer %d", i)

        inverter_cls = self.inv_cls[i]
        layer = self.model[i]
        settings = (self.bounds[i], self.inv_act[i], self.objdirs[i],
                    self.ignore_eqcond[i], self.ignore_eqvar[i],
                    self.weight_objvars[i], self.feats_ocb[i],
                    self.respect_icb[i])
        solver = inverter_cls(layer, (in_db, out_db), (in_cb, out_cb),
                              target, *settings)
        solver.solve()
        self._grbm[i] = solver
        return solver.get_sol().reshape(in_shape)
