"""autoencoders"""

import logging

import numpy as np

import theano
import theano.tensor as T
from theano.tensor.nnet import sigmoid

from . import Model, SpeedLayer
from .. import archive
from config_spec import AESpec

# module-level logger, named after this module (used e.g. by AELayer for debug output)
log = logging.getLogger(__name__)


class AELayer( SpeedLayer ):
    """A denoising autoencoder layer with tied weights.

    The layer has three parameters, in this order: ``W`` of shape
    (n_in, n_hidden), ``b`` of shape (n_hidden,) and ``b_prime`` of shape
    (n_in,). The decoder reuses ``W`` transposed.
    """

    def __init__(self, X, n_in, n_hidden, rng, thrng, dtype, clevel=0):
        """initialize an AE instance

        :param tensor X: symbolic input (theano.tensor); presumably a matrix
            with one example per row and ``n_in`` columns -- TODO confirm
        :param int n_in: input dimension
        :param int n_hidden: hidden dimension
        :param RandomState rng: random state (np.random.RandomState) used to
            draw the initial ``W``; if ``rng`` is not a RandomState, ``W``
            keeps the value set up by ``SpeedLayer``
        :param thrng: random state
            (theano.tensor.shared_randomstreams.RandomStreams) used to draw
            the corruption mask
        :param np.dtype dtype: dtype of weights
        :param clevel: corruption level; each input entry is independently
            set to 0 with probability ``clevel`` and kept unchanged otherwise
        """

        super(AELayer, self).__init__([(n_in, n_hidden), (n_hidden,), (n_in,)],
                                      ["W", "b", "b_prime"],
                                      [0.0, 0.0, 0.0],
                                      [dtype, dtype, dtype])

        if isinstance(rng, np.random.RandomState):
            log.debug("AEW: rnd")
            # W ~ U(-thr, thr) with thr = sqrt(6 / (n_in + n_hidden));
            # the biases keep the value set up by SpeedLayer
            thr = np.sqrt(6. / (n_in + n_hidden))
            self._weights_[0].set_value(np.asarray(rng.uniform(low=-thr, high=thr,
                                                               size=(n_in, n_hidden)),
                                                   dtype=dtype))

        self.theano_rng = thrng
        self.input = X
        self.corruption_level = clevel
        self.tilde_input = self.corrupt(self.input, clevel)

        w, b, bp = self.get_params()
        # hidden state, computed from the *corrupted* input
        self.encoder = sigmoid(T.dot(self.tilde_input, w) + b)
        # reconstruction of the input, with tied weights (W transposed)
        self.decoder = sigmoid(T.dot(self.encoder, w.T) + bp)

    def activation(self):
        """symbolic hidden state of the layer (the encoder output)"""
        return self.encoder

    def corrupt(self, X, corrupt_level):
        """symbolic masking noise applied to ``X``

        A binomial draw with ``n=1`` and ``p=1 - corrupt_level`` gives an
        array of 0s and 1s shaped like ``X``: 1 with probability
        ``1 - corrupt_level`` and 0 with probability ``corrupt_level``.
        Multiplying it with ``X`` zeroes the selected entries; the result is
        cast back to ``X.dtype``.

        :param X: symbolic input (theano.tensor)
        :param corrupt_level: probability of zeroing each entry
        :rtype: theano tensor with the shape and dtype of ``X``
        """

        return T.cast(self.theano_rng.binomial(
                      size=X.shape, n=1, p=1 - corrupt_level) * X, X.dtype)

    def cost(self, l1, l2):
        """symbolic training cost

        cross-entropy reconstruction cost plus the L1 and L2 weight norms
        (as computed by ``weight_norm``) scaled by ``l1`` and ``l2``.

        :param l1: coefficient of the L1 weight norm
        :param l2: coefficient of the L2 weight norm
        """
        l1_term = l1 * self.weight_norm("l1")
        l2_term = l2 * self.weight_norm("l2")

        return self.reconstruction_cost() + l1_term + l2_term

    def reconstruction_cost(self, cost_type="cross-entropy"):
        """symbolic reconstruction error of the input

        :param str cost_type: ``"cross-entropy"`` -- per-example sum over the
            input dimensions, averaged over the minibatch; or ``"L2"`` --
            squared error averaged over *every* entry (so its scale is not
            comparable to the cross-entropy cost)
        :raises ValueError: if ``cost_type`` is not one of the above
        :rtype: symbolic scalar
        """
        z = self.decoder
        x = self.input
        if cost_type == "cross-entropy":
            # sum over the size of a datapoint: with minibatches L is a
            # vector with one entry per example; T.mean below averages them
            L = -T.sum(x * T.log(z) + (1 - x) * T.log(1 - z), axis=1)
        elif cost_type == "L2":
            L = (x - z) ** 2
        else:
            # previously an unknown value fell through to an UnboundLocalError
            raise ValueError("unknown cost_type %r (expected 'cross-entropy' or 'L2')"
                             % (cost_type,))
        return T.mean(L)

    def __str__(self):
        return "[(AE) %dx%d # %.2f]" % (self.get_weights()[0].shape + (self.corruption_level,))


class AutoEncoder( Model ):
    r"""Denoising autoencoder

    A denoising autoencoder tries to reconstruct the input from a corrupted
    version of it by projecting it first into a latent space and reprojecting
    it afterwards back into the input space. Please refer to Vincent et al.,
    2008 for more details. If x is the input then equation (1) computes a
    partially destroyed version of x by means of a stochastic mapping q_D.
    Equation (2) computes the projection of the input into the latent space.
    Equation (3) computes the reconstruction z of the input, while equation
    (4) computes the reconstruction error. In this implementation W' is W
    transposed (tied weights).

    .. math::

       \tilde{x} \sim q_D(\tilde{x}|x)                                  (1)

       y = s(W \tilde{x} + b)                                           (2)

       z = s(W' y  + b')                                                (3)

       L(x,z) = -\sum_{k=1}^d [x_k \log z_k + (1-x_k) \log( 1-z_k)]     (4)

    """

    def __init__(self, ins, hs, rng, theano_rng, wdtype, corruption_level):
        """ Initialize the dA class by specifying the number of visible units (the
        dimension d of the input ), the number of hidden units ( the dimension
        d' of the latent or hidden space ) and the corruption level.

        :param ins: input dimension (int)
        :param hs: hidden dimension (int)
        :param rng: random state (np.random.RandomState)
        :param theano_rng: random state (theano.tensor.shared_randomstreams.RandomStreams)
        :param wdtype: dtype of weights (np.dtype); also the dtype of the
                symbolic input matrix
        :param corruption_level: each input entry is independently set to 0
                with probability ``corruption_level`` and kept unchanged
                otherwise

        """

        X = T.matrix("X", dtype=wdtype)
        super(AutoEncoder, self).__init__( [AELayer(X, ins, hs, rng, theano_rng,
                                                    wdtype, corruption_level)] )

    def cost(self, l1, l2):
        """symbolic training cost of the (single) autoencoder layer"""
        return self[0].cost(l1, l2)

    def update_params(self, train_batches, gradient_f, momentum, lrate):
        """step on the direction of the gradient for a whole epoch

        For each batch the gradients are computed once and split among the
        layers in order, each layer taking as many gradients as it has
        parameters; each layer then updates its speed and its weights in place.
        By definition speed is initialized to 0.
        new_speed = -rho * dE/dw + mu * speed
        new_weight = w + new_speed

        :param train_batches: indexes of batches (list)
        :param gradient_f: function that returns the sequence of gradients
            from the batch index.
        :param momentum: mu
        :param lrate: rho
        :raises ValueError: if ``gradient_f`` returns fewer gradients than the
            layers have parameters
        :rtype: none
        """

        for batch_i in train_batches:
            all_grads = gradient_f(batch_i)
            offset = 0
            for layer in self:
                n_params = len(layer.get_params())
                # slice by offset instead of repeated pop(0): no mutation of
                # gradient_f's result, no quadratic pops, and no reliance on
                # map() being eager (it is lazy on Python 3)
                l_grads = list(all_grads[offset:offset + n_params])
                if len(l_grads) != n_params:
                    raise ValueError("expected %d gradients for layer %s, got %d"
                                     % (n_params, layer, len(l_grads)))
                offset += n_params

                layer.speed_update(l_grads, momentum, lrate)
                layer.weight_update()


class AEStack(Model):
    r"""a stack of denoising autoencoders.

    Each layer is a denoising autoencoder whose input is the hidden state of
    the previous layer. A denoising autoencoder tries to reconstruct the input
    from a corrupted version of it by projecting it first into a latent space
    and reprojecting it afterwards back into the input space. Please refer to
    Vincent et al., 2008 for more details. If x is the input then equation (1)
    computes a partially destroyed version of x by means of a stochastic
    mapping q_D. Equation (2) computes the projection of the input into the
    latent space. Equation (3) computes the reconstruction z of the input,
    while equation (4) computes the reconstruction error.

    .. math::

       \tilde{x} \sim q_D(\tilde{x}|x)                                  (1)

       y = s(W \tilde{x} + b)                                           (2)

       z = s(W' y  + b')                                                (3)

       L(x,z) = -\sum_{k=1}^d [x_k \log z_k + (1-x_k) \log( 1-z_k)]     (4)

    This is a completely unsupervised model that concatenates
    autoencoders.
    """

    def __init__(self, ins, hs_lst, rng, theano_rng, wdtype, corruption_level):
        """initialize a stack of autoencoders

        :param ins: input dimension (int)
        :param list hs_lst: hidden dimensions, one int per layer (any
                         iterable is accepted; it is read once)
        :param rng: random state (np.random.RandomState)
        :param theano_rng: random state (theano.tensor.shared_randomstreams.RandomStreams)
        :param wdtype: dtype of weights (np.dtype)
        :param corruption_level: each input entry of every layer is
                         independently set to 0 with probability
                         ``corruption_level`` and kept unchanged otherwise
        """

        # materialize once: hs_lst is iterated here and stored in the spec below
        hs_lst = tuple(hs_lst)
        X = T.matrix("X", dtype=wdtype)
        layers = []
        for hs in hs_lst:
            ael = AELayer(X, ins, hs, rng, theano_rng,
                          wdtype, corruption_level)
            layers.append(ael)
            # the next layer is fed the symbolic hidden state of this one
            X = ael.encoder
            ins = hs
        super(AEStack, self).__init__(layers)
        # input / output dimension of each layer (lists, so they are indexable
        # on Python 3 as well)
        self.ish = [layer.get_weights()[0].shape[0] for layer in self]
        self.osh = [layer.get_weights()[0].shape[1] for layer in self]
        # caches of compiled theano functions, keyed by layer index
        self.__enc_f = {}
        self.__dec_f = {}
        self.mspec = AESpec._make((self.ish[0], hs_lst, corruption_level))

    def _encoder_f(self, i):
        """compiled (and cached) function: input of layer i -> its hidden state"""
        if i not in self.__enc_f:
            self.__enc_f[i] = theano.function(inputs=[self[i].input],
                                              outputs=self[i].encoder)
        return self.__enc_f[i]

    def _decoder_f(self, i):
        """compiled (and cached) function: input of layer i -> its reconstruction"""
        if i not in self.__dec_f:
            self.__dec_f[i] = theano.function(inputs=[self[i].input],
                                              outputs=self[i].decoder)
        return self.__dec_f[i]

    def compute_repr(self, i, x, hidden=True):
        """compute the representation of x by the i-th layer.

        The theano function is compiled on first use and cached. Note that the
        encoder is applied to the corrupted input, so the result is random
        unless the model was built with a corruption level of 0.

        :param int i: the layer index [0, ..., len(self)-1]
        :param ndarray x: input. must conform
                          (i.e., be reshapeable) to self.ish[i]
        :param bool hidden: compute the hidden state if True, the
                          reconstruction of the input otherwise
        :rtype: ndarray conforming to self.osh[i] (hidden) or
                self.ish[i] (reconstruction)
        """

        f = self._encoder_f(i) if hidden else self._decoder_f(i)
        return f(x)

    def compute_state(self, i, x, hidden=True):
        """compute the representation of `x` up to the i-th layer

        x is considered as input to the model and is propagated
        through the first i layers

        :param int i: the layer index [0, ..., len(self)-1]
        :param ndarray x: input to the model. must conform
            (i.e., be reshapeable) to self.ish[0]
        :param bool hidden: compute the hidden state of layer i if True,
            its reconstruction of its input otherwise
        :rtype: ndarray"""

        # hidden states of layers 0 .. i-1
        for layer_i in range(i):
            x = self._encoder_f(layer_i)(x)

        # state of the i-th layer
        return self.compute_repr(i, x, hidden)

    def update_params(self, train_batches, gradient_f, momentum, lrate, lidx):
        """step on the direction of the gradient of one layer

        step on the direction of the gradient for
        a whole epoch and update the params of layer ``lidx`` in place.
        By definition speed is initialized to 0.
        new_speed = -rho * dE/dw + mu * speed
        new_weight = w + new_speed

        :param train_batches: indexes of batches (list)
        :param gradient_f: function that returns the sequence of
                            gradients from the batch index; its first
                            entries are used as the gradients of layer
                            ``lidx``, one per parameter
        :param momentum: mu
        :param lrate: rho
        :param int lidx: index of the layer to update
        :raises ValueError: if ``gradient_f`` returns fewer gradients than
                            the layer has parameters
        :rtype: None
        """

        layer = self[lidx]
        n_params = len(layer.get_params())
        for batch_i in train_batches:
            all_grads = gradient_f(batch_i)
            # slice instead of repeated pop(0): no mutation of gradient_f's
            # result and no reliance on map() being eager
            l_grads = list(all_grads[:n_params])
            if len(l_grads) != n_params:
                raise ValueError("expected %d gradients for layer %d, got %d"
                                 % (n_params, lidx, len(l_grads)))

            layer.speed_update(l_grads, momentum, lrate)
            layer.weight_update()

    @classmethod
    def _from_archive(cls, arch, path, rng, thrng, dtype, zero_clevel=True):
        """instantiate a model and load archived weights

        :param str arch: path to the HDF archive
        :param str path: path (inside the archive) of the training directory
            (usually: dataset/trid)
        :param numpy.RandomState rng: random generator
        :param theano_rng: random state (theano.tensor.shared_randomstreams.RandomStreams)
        :param numpy.dtype dtype: data type of weights (same as data datatype)
        :param bool zero_clevel: do not corrupt input (corruption level 0)
            instead of the archived one. handy when you call this not to
            train the model
        :rtype: :py:class:`dimer.nnet.autoencoder.AEStack`
        """

        ms = AESpec._from_archive(arch, path)
        # use cls so that subclasses get an instance of themselves
        model = cls(ms.ins, ms.hiddens, rng, thrng, dtype,
                    (0 if zero_clevel else ms.clevel))
        model.load(archive.join(arch, path))
        return model
