''' Created on Jul 15, 2013 '''

import logging
log = logging.getLogger(__name__)
import abc
from itertools import product

import numpy as np
rng = np.random.RandomState()
try:
    import gurobipy as grb
except ImportError:
    log.critical("You need to have gurobi installed, to use this module")
import dimer
from . import tup_prod



class GrbModelHolder( object ):
    "provides a simple interface to a gurobi model"

    # Python 2 metaclass hook: required for the @abc.abstractmethod
    # declarations below to be enforced when a subclass is instantiated
    __metaclass__ = abc.ABCMeta
    OBJMIN = grb.GRB.MINIMIZE
    OBJMAX = grb.GRB.MAXIMIZE

    def __init__(self, name):
        """create the wrapped gurobi model

        :param str name: name given to the gurobi model
        """
        self.model = grb.Model(name)
        # to silence solver output:
        # self.model.setParam(grb.GRB.Param.LogToConsole, 0)

    @abc.abstractmethod
    def set_variables(self, ignore_eqcond, respect_icb):
        "allocate the model variables (implemented by subclasses)"
        raise NotImplementedError

    @abc.abstractmethod
    def set_constraints(self, ignore_eqcond):
        "add the model constraints (implemented by subclasses)"
        raise NotImplementedError

    @abc.abstractmethod
    def set_objective(self, weight_vars):
        "set the model objective (implemented by subclasses)"
        raise NotImplementedError

    def relax(self, relaxobjtype=2, minrelax=True, vrelax=False, crelax=True):
        """turn self.model into its feasibility relaxation (Model.feasRelaxS)

        The relaxed model, when optimized, minimizes the amount by which the
        solution violates the bounds and/or linear constraints of the
        original model. You still need to call optimize on the model to
        compute the relaxed solution.

        Cost of a violation of magnitude v, per relaxobjtype:
          - 0: v          (sum of magnitudes)
          - 1: v * v      (sum of squares)
          - 2: 1          (number of violated bounds/constraints)

        If minrelax is False, optimizing the relaxed model minimizes the
        violation cost only. If minrelax is True, it minimizes the original
        objective among the solutions of minimum violation cost. In that case
        feasRelaxS itself solves an optimization problem, which can be
        expensive.

        :param int relaxobjtype: cost function of the violations (0, 1 or 2)
        :param bool minrelax: type of relaxation (see above)
        :param bool vrelax: whether variable bounds may be relaxed
        :param bool crelax: whether constraints may be relaxed
        :returns: the value returned by feasRelaxS (per the gurobi docs, the
            relaxation objective when minrelax is True)
        """

        return self.model.feasRelaxS(relaxobjtype, minrelax, vrelax, crelax)

    def solve(self, **relax_args):
        """call optimize and, if the model is infeasible, relax and re-optimize

        :param dict relax_args: keyword arguments forwarded to self.relax
        """

        self.model.optimize()
        if self.model.status == grb.GRB.status.INFEASIBLE:
            log.warning("***###*** MODEL (%s) INFEASIBLE ***###***", self.model.ModelName)
            log.warning("***###*** don't worry! finding min. relaxation ***###***")
            self.relax(**relax_args)
            self.model.optimize()


class LayerModel( object ):
    """abstract class for a formulation over a layer

    Subclasses must set ``self.dbi`` and ``self.dbo`` (arrays on the input
    and output side of the layer) before the index iterators are used.
    """

    def __init__(self, layer):
        """remember the layer and check that its type is supported

        :param layer: a dimer.nnet.nccn ConvPoolLayer, HiddenLayer,
            LinearRegression or LogisticReg instance
        :raises ValueError: for any other (exact) layer type
        """
        self.layer = layer
        self.ltype = type(layer)
        supported = (dimer.nnet.nccn.ConvPoolLayer,
                     dimer.nnet.nccn.HiddenLayer,
                     dimer.nnet.nccn.LinearRegression,
                     dimer.nnet.nccn.LogisticReg)
        if self.ltype not in supported:
            raise ValueError("unsupported type: %s" % self.ltype)

    @property
    def is_cp(self):
        "is this layer convolutional?"
        return self.ltype == dimer.nnet.nccn.ConvPoolLayer

    @property
    def is_hl(self):
        "is this layer Hidden?"
        return self.ltype == dimer.nnet.nccn.HiddenLayer

    @property
    def is_regr(self):
        "is this a linear regression layer?"
        return self.ltype == dimer.nnet.nccn.LinearRegression

    @property
    def is_logreg(self):
        "is this a logistic regression layer?"
        return self.ltype == dimer.nnet.nccn.LogisticReg

    def var_idx_iter(self):
        "iterate, in C order, over the index tuples of self.dbi (variables)"
        return np.ndindex(*self.dbi.shape)

    def cstr_idx_iter(self):
        "iterate, in C order, over the index tuples of self.dbo (constraints)"
        return np.ndindex(*self.dbo.shape)


class NormOpt(GrbModelHolder, LayerModel):
    """a formulation that optimizes the norm of the extracted features"""

    # TODO : add constants for ignore_eqcond and ignore_eqvar. eg: SET_TO_BOUNDS etc
    def __init__(self, layer, db, cb, fi, bounds, inv_activation, objdir,
                 ignore_eqcond, ignore_eqvar, weight_objvars,
                 feats_as_ocb, respect_cbi):
        """initialize Norm Optimizer instance

        :param layer: the CNN layer
        :param tuple db: decision boundary in input and output.
            tuple of the form (dbi, dbo)
        :param tuple cb: constraint boundary in input and output.
            tuple of the form (cbi, cbo)
        :param ndarray fi: features inferred by optimizing the layer above.
        :param tuple bounds: bounds of variables when the constraint
            boundary does not apply.  of the type (lb, ub)
        :param function inv_activation: the inverse of the activation
            of this layer
        :param int objdir: direction of the objective function
        :param bool ignore_eqcond: whether to ignore constraints
            corresponding to cbo == dbo
        :param bool ignore_eqvar: whether to ignore (i.e., set to
            self.bounds) variable bounds corresponding to cbi == dbi
        :param bool weight_objvars: weight each objective term by
            (cbi - dbi); otherwise only the sign of (cbi - dbi) is used
        :param bool feats_as_ocb: constraint boundary in output is set
            to computed features on the previous layer (i.e., cb[1] or cbo)
        :param bool respect_cbi: if True, cbi is the far bound of each
            variable; otherwise the far bound is taken from ``bounds``
        """

        assert type(bounds) == type(db) == type(cb) == tuple
        assert len(bounds) == len(db) == len(cb) == 2

        LayerModel.__init__(self, layer)
        # keep read-only *views*: this model must not modify its inputs, but
        # freezing the caller's own arrays would be a surprising side effect
        self.dbi = self._readonly(db[0])
        self.dbo = self._readonly(db[1])
        self.cbi = self._readonly(cb[0])
        self.feats = self._readonly(fi)
        cbo = cb[1]
        if feats_as_ocb:
            cbo = np.copy(fi)
        self.cbo = self._readonly(cbo)

        self.bounds = bounds
        self.inv_activation = inv_activation
        self.objdir = objdir

        assert self.dbi.shape == self.cbi.shape
        assert self.dbo.shape == self.cbo.shape
        GrbModelHolder.__init__(self, str(layer))

        self.xvars = {}
        self.set_variables(ignore_eqvar, respect_cbi)
        self.set_constraints(ignore_eqcond)
        self.set_objective(weight_objvars)
        log.info("%s", str(self.model))

    @staticmethod
    def _readonly(arr):
        "return a read-only view of arr, leaving arr itself writeable"
        view = arr.view()
        view.flags.writeable = False
        return view

    def __conv_op(self, w, k, f, i, j):
        """return a function mapping an (x offset, w offset) pair to one
        term of the convolution for kernel k, input map f, position (i, j)"""
        def co(idx_pair):
            xi, wi = idx_pair
            return self.xvars[0, f, i + xi[0], j + xi[1]] * float(w[k, f, wi[0], wi[1]])
        return co

    def set_constraints(self, ignore_eqcond):
        """define one constraint per output entry: lhs(x) ? inv_act(cbo) - b

        The relation (>=, <= or ==) is chosen by self.__constr from the
        comparison of cbo and dbo.
        """

        W, b = self.layer.get_weights()

        if self.is_cp:
            Z = self.inv_activation(self.cbo)
            (m, n) = W.shape[2:]
            w_idx = list(tup_prod((m, n), flip=True))
            x_idx = list(tup_prod((m, n)))
            conv_idx = list(zip(x_idx, w_idx))

            def lhs_f(cidx):
                _, kern, r, c = cidx
                lhs = 0
                for f in range(W.shape[1]):
                    co = self.__conv_op(W, kern, f, r, c)
                    lhs += grb.quicksum([co(pair) for pair in conv_idx])
                return lhs

            cstr_nfrm = "%d_%d_%d_%d"
        else:
            if self.is_hl:
                Z = self.inv_activation(self.cbo)
                if np.any(np.isinf(Z)):
                    # saturated activations have an infinite inverse
                    MAX_SIGMOID = 10.
                    log.warning("saturating values in layer %s", str(self.layer))
                    log.warning("clipping to %.f", MAX_SIGMOID)
                    Z[np.isposinf(Z)] = MAX_SIGMOID
                    Z[np.isneginf(Z)] = -MAX_SIGMOID
            else:
                # the inverse activation of the regression layers also takes
                # sum(exp(cbi W + b)) (presumably a softmax normalizer)
                Z = self.inv_activation(self.cbo,
                                        np.exp(np.dot(self.cbi, W) + b).sum())

            def lhs_f(cidx):
                c = cidx[1]
                return grb.quicksum([self.xvars[vidx] * float(W[vidx[1], c])
                                     for vidx in self.var_idx_iter()])

            cstr_nfrm = "%d_%d"

        def rhs_f(cidx):
            # cidx[1] is the output map / unit, which indexes the bias
            return float(Z[cidx] - b[cidx[1]])

        self.__constr(self.cstr_idx_iter(), lhs_f, rhs_f, cstr_nfrm, ignore_eqcond)
        self.model.update()

    def __constr(self, idx_iter, lhs_f, rhs_f, cstr_nfrm, ignore_eqcond):
        """add lhs_f(idx) >=, <= or == rhs_f(idx) for each idx in idx_iter

        The relation is >= where cbo > dbo, <= where cbo < dbo and == where
        they are equal (skipped if ignore_eqcond). The same for all layers.
        """

        for idx in idx_iter:
            lhs = lhs_f(idx)
            rhs = rhs_f(idx)
            name = cstr_nfrm % idx
            if self.cbo[idx] > self.dbo[idx]:
                self.model.addConstr(lhs >= rhs, name)
                log.debug("%s >= %s %s", lhs, rhs, name)
            elif self.cbo[idx] < self.dbo[idx]:
                self.model.addConstr(lhs <= rhs, name)
                log.debug("%s <= %s %s", lhs, rhs, name)
            elif ignore_eqcond:
                log.debug("relaxing constraint %s", name)
            else:
                self.model.addConstr(lhs == rhs, name)
                log.debug("%s == %s %s", lhs, rhs, name)

    def var_nfrm(self):
        "variable name format, e.g. 'x%d_%d' for a 2-d input"

        return "x%s" % "_".join(("%d",) * len(self.dbi.shape))

    def set_objective(self, weight_vars):
        """set objective sum_i w_i * (x_i - dbi_i), optimized in self.objdir

        w_i is (cbi_i - dbi_i) if weight_vars, else its sign (0 stays 0).
        """

        terms = []
        for idx in self.var_idx_iter():
            mult = float(self.cbi[idx] - self.dbi[idx])
            if mult and (not weight_vars):
                mult /= abs(mult)
            terms.append(mult * (self.xvars[idx] - float(self.dbi[idx])))

        self.model.setObjective(grb.quicksum(terms), self.objdir)
        self.model.update()

    def set_variables(self, ignore_eqcond, respect_icb):
        """allocate one continuous variable per input entry

        A variable ranges between dbi and a "far bound" on the side of cbi.

        :param bool ignore_eqcond: (``ignore_eqvar`` in __init__) where
            cbi == dbi, bound the variable by self.bounds instead of
            fixing it to cbi
        :param bool respect_icb: if False, the far bound is self.bounds[0]
            (where cbi < dbi) or self.bounds[1] (where cbi > dbi) instead
            of cbi
        """

        name_frmt = self.var_nfrm()
        far_bound = np.copy(self.cbi)
        if not respect_icb:
            far_bound[self.cbi < self.dbi] = self.bounds[0]
            far_bound[self.cbi > self.dbi] = self.bounds[1]

        for idx in self.var_idx_iter():
            if self.cbi[idx] > self.dbi[idx]:
                b, s = (self.dbi[idx], far_bound[idx]), "^"
            elif self.cbi[idx] < self.dbi[idx]:
                b, s = (far_bound[idx], self.dbi[idx]), "_"
            elif ignore_eqcond:
                b, s = self.bounds, "*"
            else:
                b, s = (far_bound[idx], far_bound[idx]), "."
            name = name_frmt % idx
            self.xvars[idx] = self.model.addVar(lb=float(b[0]), ub=float(b[1]),
                                                vtype="C", name=name)
            log.debug("%s : %f <= %s <= %f", s, b[0], name, b[1])
        self.model.update()

    def x_cb_cost(self):
        """return (objective value of the solved model, sum |cbi - dbi|)

        The cbi == dbi entries are summed as |cbi - dbi|, i.e. they add 0 for
        finite values.
        """
        p = self.cbi > self.dbi
        m = self.cbi < self.dbi
        u = self.cbi == self.dbi
        CB = self.cbi
        x_cost = self.model.ObjVal
        cb_cost = (np.sum(CB[p] - self.dbi[p]) +
                   np.sum(self.dbi[m] - CB[m]) +
                   np.sum(np.abs(CB[u] - self.dbi[u])))
        return (x_cost, cb_cost)

    def get_sol(self):
        "return the solved variable values as an array shaped like dbi"
        _X = np.empty(self.dbi.shape)
        for i in self.var_idx_iter():
            _X[i] = self.xvars[i].getAttr("X")
        return _X


class ThrDist(GrbModelHolder, LayerModel):
    """ model for a layer: Y = f(X; W,b)

    Formulation is

    min |~X - dbi|
    s.t sigma(~XW + b) = dbo,  lb < ~X < ub

    let (~X - dbi) = X
    min |X|
    s.t. XW = sigma^-1(dbo) - b - dbiW, lb - dbi < X < ub - dbi

    let |X| = Xp + Xm and X = Xp - Xm
    min Xp + Xm
    s.t. (Xp - Xm)W = sigma^-1(dbo) - b - dbiW
    where 0 < Xm < -(lb - dbi), 0 < Xp < ub - dbi

    and dbi and dbo are, possibly, mean arrays of training data
    (for convolutional layers the dbi term is kept on the left-hand side:
    (Xp - Xm + dbi) * W = sigma^-1(dbo) - b)

    This class implements that using the trick in
    http://www.aimms.com/aimms/download/manuals/aimms3om_linearprogrammingtricks.pdf
    page 63-64"""

    # entries where both Xp and Xm exceed this are reported as a violation
    # of |X| = Xp + Xm; solver values are only accurate up to tolerances
    COMPL_TOL = 1e-6

    def __init__(self, layer, dbi, dbo, bounds, inv_activation):
        """build the full model (variables, constraints, objective)

        :param layer: the CNN layer
        :param ndarray dbi: decision boundary on the input side
        :param ndarray dbo: decision boundary on the output side
        :param tuple bounds: (lb, ub) bounds of ~X
        :param function inv_activation: inverse of the layer activation
        """
        self.mv = {}
        self.pv = {}
        self.dbi = dbi
        self.dbo = dbo
        log.info("%s ---> %s", str(dbi.shape), str(self.dbo.shape))
        self._X = None
        self.invact = inv_activation
        LayerModel.__init__(self, layer)

        GrbModelHolder.__init__(self, str(layer))
        self.set_variables(*bounds)
        self.set_constraints()
        self.set_objective()

    @staticmethod
    def _var_bounds(thr, lb, ub):
        """bounds ((p_lb, p_ub), (m_lb, m_ub)) of Xp and Xm for X = ~X - thr
        with lb <= ~X <= ub, handling thr below lb or above ub"""
        if thr - lb < 0:
            pb = (lb - thr, ub - thr)
            mb = (0, 0)
        elif ub - thr < 0:
            pb = (0, 0)
            mb = (thr - ub, thr - lb)
        else:
            pb = (0, ub - thr)
            mb = (0, thr - lb)
        return (pb, mb)

    def var_nfrm(self, sign):
        "variable name format, e.g. 'xp_%d_%d' for sign 'p' and a 2-d input"
        return "x%s_%s" % (sign, "_".join(("%d",) * len(self.dbi.shape)))

    def set_variables(self, lb, ub, vt="C"):
        """allocate Xp (self.pv) and Xm (self.mv) for every input entry

        :param lb: lower bound of ~X
        :param ub: upper bound of ~X
        :param str vt: gurobi variable type
        """
        name_frmt_p, name_frmt_m = self.var_nfrm("p"), self.var_nfrm("n")
        for idx in self.var_idx_iter():
            pb, mb = self._var_bounds(self.dbi[idx], lb, ub)
            self.pv[idx] = self.model.addVar(lb=pb[0], ub=pb[1],
                                             vtype=vt, name=name_frmt_p % idx)
            self.mv[idx] = self.model.addVar(lb=mb[0], ub=mb[1],
                                             vtype=vt, name=name_frmt_m % idx)
            log.debug("%f <= %s <= %f and %f <= %s <= %f",
                      mb[0], name_frmt_m % idx, mb[1],
                      pb[0], name_frmt_p % idx, pb[1])
        self.model.update()

    def __constr(self, idx_iter, lhs_f, rhs_f, cstr_nfrm):
        """add lhs_f(idx) == rhs_f(idx) for every idx in idx_iter"""

        for idx in idx_iter:
            lhs = lhs_f(idx)
            rhs = rhs_f(idx)
            name = cstr_nfrm % idx
            self.model.addConstr(lhs == rhs, name)
            log.debug("%s == %s %s", lhs, rhs, name)

    def __conv_op(self, W, k, f, i, j):
        """return a function mapping an (x offset, w offset) pair to one
        term (Xp - Xm + dbi) * w of the convolution for kernel k, input map f,
        position (i, j)"""
        def co(idx_pair):
            xi, wi = idx_pair
            pos = (0, f, i + xi[0], j + xi[1])
            x = self.pv[pos] - self.mv[pos] + float(self.dbi[pos])
            return x * float(W[k, f, wi[0], wi[1]])
        return co

    def set_constraints(self):
        "add one equality constraint per output entry (see class docstring)"
        W, b = self.layer.get_weights()
        Z = self.invact(self.dbo)

        if self.is_cp:
            (m, n) = W.shape[2:]
            w_idx = list(tup_prod((m, n), flip=True))
            x_idx = list(tup_prod((m, n)))
            conv_idx = list(zip(x_idx, w_idx))

            def lhs_f(cidx):
                _, kern, r, c = cidx
                lhs = 0
                for f in range(W.shape[1]):
                    co = self.__conv_op(W, kern, f, r, c)
                    lhs += grb.quicksum([co(pair) for pair in conv_idx])
                return lhs

            def rhs_f(cidx):
                # cidx[1] is the output map, which indexes the bias
                return float(Z[cidx] - b[cidx[1]])

            self.__constr(self.cstr_idx_iter(), lhs_f, rhs_f, "%d_%d_%d_%d")
        else:
            # sigma^-1(dbo) - b - dbi W for all outputs at once; indexing the
            # result gives a scalar per constraint
            offset = Z - b - np.dot(self.dbi, W)

            def lhs_f(cidx):
                c = cidx[1]
                return grb.quicksum([(self.pv[v] - self.mv[v]) * float(W[v[1], c])
                                     for v in self.var_idx_iter()])

            def rhs_f(cidx):
                return float(offset[cidx])

            # "%d_%d" (not "%d%d") keeps names unambiguous and matches NormOpt
            self.__constr(self.cstr_idx_iter(), lhs_f, rhs_f, "%d_%d")
        self.model.update()

    def set_objective(self):
        "minimize sum(Xp) + sum(Xm), i.e. |~X - dbi|"
        self.model.setObjective(grb.quicksum(list(self.mv.values())) +
                                grb.quicksum(list(self.pv.values())),
                                self.OBJMIN)
        self.model.update()

    def _x(self, idx_iter):
        """recover ~X = (Xp - Xm) + dbi from the solved model

        The result is computed once and cached in self._X.

        :raises ValueError: if, at some entry, both Xp and Xm exceed
            COMPL_TOL (then Xp + Xm is not |Xp - Xm|)
        """
        tol = self.COMPL_TOL

        def _get_s(m_, p_):
            p, m = p_.getAttr("x"), m_.getAttr("x")
            # exact "== 0" tests would reject solutions that are only
            # accurate up to the solver tolerances (e.g. p = -1e-12)
            if p > tol and m > tol:
                raise ValueError("p=%f, m=%f" % (p, m))
            return p - m

        if self._X is None:
            self._X = np.empty(self.dbi.shape)
            for i in idx_iter:
                self._X[i] = _get_s(self.mv[i], self.pv[i])
            self._X += self.dbi
        return self._X

    @property
    def X(self):
        "the solution ~X, shaped like dbi"
        return self._x(self.var_idx_iter())
