"""module for inverting the computation of a neural network
using LPs and MIPs"""


import numpy as np
from itertools import product


def draw_array(x, thr):
    """'draw' a 1-D array as rows of dots and spaces, one row per threshold.

    Each row is a binarized rendering of the vector: position ``i`` holds
    "." if ``x[i] >= threshold`` else " ". Rows are emitted for the
    thresholds in reverse order (last threshold first), so passing ascending
    thresholds yields a top-down picture.

    :param ndarray x: a 1d array
    :param thr: a single thresholding value, or a sequence of them
    :rtype: list of strings, one string of length ``x.shape[0]`` per threshold
    """

    assert len(x.shape) == 1

    # a bare scalar is treated as a one-element threshold sequence
    levels = np.atleast_1d(thr)

    # plain conditional expression instead of a tuple-unpacking lambda,
    # which is a syntax error on Python 3
    return ["".join("." if v >= lb else " " for v in x)
            for lb in levels[::-1]]


def tup_prod(sh, flip=False):
    """Cartesian product of the index sets {0, ..., sh[0]-1} X {0, ..., sh[1]-1}.

    :param sh: tuple of length 2
    :param flip: when true, both index sets are walked downwards, i.e.
                 {sh[0]-1 down to 0} X {sh[1]-1 down to 0}
    :return: iterator over (i, j) pairs"""

    assert len(sh) == 2
    rows, cols = sh
    if not flip:
        return product( range(rows), range(cols) )
    descending_rows = range(rows - 1, -1, -1)
    descending_cols = range(cols - 1, -1, -1)
    return product( descending_rows, descending_cols )

def unpool(x, p):
    """un pool the array by expanding its last two dimensions

    Every element is replicated into a ``p[0]`` x ``p[1]`` tile along the
    last two axes and the result is divided by ``p[0] * p[1]``, so each tile
    sums to the value it was expanded from. Leading axes, if any, are left
    untouched. A 1-D input is treated as a single row.

    :param ndarray x: array with at least 1 dimension
    :param tuple p: pool size (tuple of length 2)
    :return: float array of dimensions x0, x1, x2, ..., xn-1*p[0], xn*p[1]"""

    assert len(p) == 2
    if len(x.shape) == 1:
        x = x.reshape( 1, -1 )
    assert len(x.shape) > 1
    z = np.prod( p )

    # always produce a float64 result, whatever the input dtype
    x = np.asarray( x, dtype=np.float64 )

    # tile rows, then columns; addressing the axes from the end makes this
    # work for any number of leading dimensions
    x_ = np.repeat( np.repeat(x, p[0], axis=-2), p[1], axis=-1 )
    return x_ / z


def constraint_boundary(s, t, x, lb, ub):
    """Given state, decision boundary, and x arrays compute a constraint boundary

    The distance |x - t| is used as an indication of the importance of a
    feature. It is normalized to [0, 1] by (ub - t) where x lies above t and
    by (t - lb) where x lies below t. Where the normalized distance is 1 the
    constraint boundary equals the state s; where it is 0 the constraint
    boundary equals the decision boundary t.

    If a constraint in the previous layer was relaxed, x and s might be on
    opposite sides of the decision boundary. In that case the feature weight
    is set to 0, so the constraint boundary is completely relaxed
    (i.e., = decision boundary).

    :param s: the state of the model for a certain label. this is the
              representation of the average signal over training examples of
              a given label
    :param t: the threshold signal indicating the decision boundary for the
              labels. this is equal to the average of all training examples
              (regardless of label)
    :param x: current estimate (from the LP) of the state of the network
    :param lb: lower bound for x; must be strictly below min(t)
    :param ub: upper bound for x; must be strictly above min(t)
    :return: array of the same shape as s, t and x

    NOTE(review): the ub check compares against min(t), not max(t) -- confirm
    this is intended. The later positivity check on the normalizer is what
    actually guards the division.
    """

    shapes = (s.shape, t.shape, x.shape)
    assert shapes[0] == shapes[1] == shapes[2], "%s!=%s!=%s" % (str(shapes[0]),
                                                                str(shapes[1]),
                                                                str(shapes[2]))

    t_min = np.min(t)
    assert lb < t_min and ub > t_min

    ## per-feature normalizer: distance from t to the bound on x's side of t
    ## (stays 1 where x == t)
    above = x > t
    below = x < t
    scale = np.ones( x.shape )
    scale[ above ] = (ub - t)[ above ]
    scale[ below ] = (t - lb)[ below ]
    assert np.all( scale > 0 )

    ## signed offset of the state from the boundary, weighted by the
    ## normalized offset of x
    s_offset = s - t
    weighted = s_offset * ((x - t) / scale)

    ## negative product <=> s and x are on opposite sides of t: relax to t
    weighted[ weighted < 0 ] = 0
    assert np.all( weighted >= 0 )

    ## restore the side of t that s lies on
    return (weighted * np.sign(s_offset)) + t

def inv_sigmoid( x ):
    """inverse of the sigmoid function (log-odds), applied element-wise

    :param ndarray x: an array
    :rtype: an array of the same shape as x holding log(x / (1 - x))
    """

    odds = x / (1 - x)
    return np.log( odds )

def sigmoid( x ):
    """logistic sigmoid, applied element-wise

    :param ndarray x: an array
    :rtype: an array of the same shape as x holding 1 / (1 + exp(-x))
    """

    denominator = 1 + np.exp(-x)
    return 1 / denominator

def softmax( x ):
    """softmax function

    The largest entry is subtracted before exponentiating. This leaves the
    result mathematically unchanged, but keeps exp() from overflowing to inf
    (and the output from turning into NaN) for large inputs.

    :param ndarray x: a 1d array, or an array whose first dimension is 1
    :rtype: an array of the same shape as x with softmax values
    """

    assert len(x.shape) == 1 or x.shape[0] == 1

    # max() is undefined on an empty array; nothing to shift in that case
    shift = x.max() if x.size else 0
    e = np.exp( x - shift )
    return e / e.sum()

def inv_softmax(x, sum_of_exp):
    """pseudo-inverse softmax function

    Undoes the normalization of a softmax by multiplying with the supplied
    normalizer before taking the logarithm.

    :param ndarray x: a 1d array, or an array whose first dimension is 1
    :param sum_of_exp: factor each entry of x is multiplied by before the log
    :rtype: an array of the same shape as x with pseudo-inverse softmax values
    """

    shape = x.shape
    assert len(shape) == 1 or shape[0] == 1
    unnormalized = sum_of_exp * x
    return np.log( unnormalized )

