"""convolutional neural nets"""

import logging
import os

import dimer
import numpy as np
import theano
import theano.tensor as T
from dimer import archive
from dimer.nnet import Model, SpeedLayer
from dimer.nnet.config_spec import CnnModelSpec, MlpModelSpec
from theano.tensor.nnet import conv, sigmoid, softmax
from theano.tensor.signal import downsample

log = logging.getLogger(__name__)


# TODO: rename to LinearReg (to match the naming of LogisticReg)
class LinearRegression(SpeedLayer):
    """Single-output regression layer.

    NOTE(review): despite the name, :py:meth:`activation` passes the affine
    map through a sigmoid, so the output lies in (0, 1). Confirm that this
    is intended before relying on it for unbounded targets.
    """

    def __init__(self, X, n_in, n_out, rng, dtype):
        """Create the layer parameters and bind the symbolic input.

        Two parameters are registered with SpeedLayer: a weight matrix
        "lrW" of shape (n_in, n_out) initialised from `rng`, and a bias
        "lrb" of shape (n_out,) initialised with 0.0.

        :param tensor X: input X (symbolic variable)
        :param int n_in: input dimension
        :param int n_out: output dimension. must be 1
        :param RandomState rng: random stream
        :param np.dtype dtype: data type of W and b
        :raises ValueError: if `n_out` is different from 1"""

        if n_out != 1:
            raise ValueError("nout (%s) must be a scalar" % str(n_out))

        param_shapes = [(n_in, n_out), (n_out,)]
        param_names = ["lrW", "lrb"]
        param_inits = [rng, 0.0]
        param_dtypes = [dtype, dtype]
        SpeedLayer.__init__(self, param_shapes, param_names,
                            param_inits, param_dtypes)
        self.input = X

    def activation(self):
        "sigmoid of (input . W + b); samples are rows of the input"

        weights, bias = self.get_params()
        affine = T.dot(self.input, weights) + bias
        return sigmoid(affine)

    def __str__(self):
        "short tag with the shape of the weight matrix"

        w_shape = self.get_weights()[0].shape
        return "[(LIN_REG) %dx%d]" % w_shape


class LogisticReg(SpeedLayer):
    """Softmax (logistic regression) classification layer"""

    def __init__(self, X, n_in, n_out, rng, dtype):
        """Create the layer parameters and the symbolic class predictions.

        Two parameters are registered with SpeedLayer: a weight matrix
        "llW" of shape (n_in, n_out) initialised from `rng`, and a bias
        "llb" of shape (n_out,) initialised with 0.0.

        After construction the instance exposes

        * ``p_y_given_x``: the softmax activation of the layer
        * ``y_hat``: argmax of ``p_y_given_x`` along axis 1, i.e. the
          chosen class of every row

        :param tensor X: input X (symbolic variable)
        :param int n_in: input dimension
        :param int n_out: output dimension (number of classes)
        :param RandomState rng: random stream
        :param np.dtype dtype: data type of W and b"""

        param_shapes = [(n_in, n_out), (n_out,)]
        param_names = ["llW", "llb"]
        super(LogisticReg, self).__init__(param_shapes, param_names,
                                          [rng, 0.0], [dtype, dtype])

        # activation() reads self.input, so the input is bound first
        self.input = X

        class_probs = self.activation()
        self.p_y_given_x = class_probs
        # the index of the largest probability is the predicted class
        self.y_hat = T.argmax(class_probs, axis=1)

    def activation(self):
        "softmax of (input . W + b); samples are rows of the input"

        weights, bias = self.get_params()
        affine = T.dot(self.input, weights) + bias
        return softmax(affine)

    def __str__(self):
        "short tag with the shape of the weight matrix"

        w_shape = self.get_weights()[0].shape
        return "[(LOG_REG) %dx%d]" % w_shape


class HiddenLayer(SpeedLayer):
    """Fully connected sigmoid hidden layer of a feed-forward net"""

    def __init__(self, X, n_in, n_out, rng, dtype):
        """A hidden layer that computes sigmoid(np.dot(X, W) + b)

        Both parameters ("hlW" of shape (n_in, n_out) and "hlb" of shape
        (n_out,)) are first created with the 0.0 init value. If `rng` is a
        numpy RandomState, W is then overwritten with samples drawn
        uniformly from [-thr, thr], where thr = sqrt(6 / (n_in + n_out)).
        With any other `rng` the weights keep their 0.0 initialisation.

        :param tensor X: input X (symbolic variable)
        :param int n_in: input dimension
        :param int n_out: output dimension (number of hidden units)
        :param RandomState rng: random stream
        :param np.dtype dtype: data type of W and b"""

        super(HiddenLayer, self).__init__([(n_in, n_out), (n_out,)],
                                          ["hlW", "hlb"], [0.0, 0.0],
                                          [dtype, dtype])

        # isinstance (rather than an exact type match) also accepts
        # subclasses of RandomState
        if isinstance(rng, np.random.RandomState):
            thr = np.sqrt(6. / (n_in + n_out))
            init_w = rng.uniform(low=-thr, high=thr, size=(n_in, n_out))
            self._weights_[0].set_value(np.asarray(init_w, dtype=dtype))

        self.input = X

    def activation(self):
        "sigmoid activation of XW + b (samples are rows)"

        W, b = self.get_params()
        return sigmoid(T.dot(self.input, W) + b)

    def __str__(self):
        "short tag with the shape of the weight matrix"

        return "[(HIDDEN) %dx%d]" % self.get_weights()[0].shape


class ConvPoolLayer(SpeedLayer):
    """LeNet conv-pool layer: convolution, max-pooling, bias and sigmoid"""

    def __init__(self, X, fshape, ishape, rng, poolsize, dtype):
        """Le Cun convolutional layer

        Two parameters are registered with SpeedLayer: the filter bank
        "cpW" of shape `fshape` and one bias per filter, "cpb", of shape
        (fshape[0],) initialised with 0.0. If `rng` is a numpy RandomState
        the filters are overwritten with samples drawn uniformly from
        [-thr, thr], where thr = sqrt(3 / prod(fshape[1:])).

        :param tensor X: input X (symbolic variable)
        :param fshape: (# filters, # in_feature_maps, width, height)
        :param ishape: (batch_size, # feature_maps, width, height)
        :param RandomState rng: random stream
        :param poolsize: (width, height) of the max-pooling window
        :param np.dtype dtype: data type of the filters and the bias
        :raises AssertionError: if the number of input feature maps of
            `fshape` and `ishape` differ
        """
        # explicit raise instead of `assert`, so that the check is not
        # stripped when python runs with -O
        if fshape[1] != ishape[1]:
            raise AssertionError("nr. of feature maps should not change")

        super(ConvPoolLayer, self).__init__([fshape, (fshape[0],)],
                                            ["cpW", "cpb"], [rng, 0.0],
                                            [dtype, dtype])
        if isinstance(rng, np.random.RandomState):
            thr = np.sqrt(3. / np.prod(fshape[1:]))
            init_w = rng.uniform(low=-thr, high=thr, size=fshape)
            self._weights_[0].set_value(np.asarray(init_w, dtype=dtype))

        self.input = X
        self.ishape = ishape
        self.fshape = fshape
        self.pshape = poolsize

    def activation(self):
        """sigmoid(max_pool(conv2d(input, W)) + b)

        the bias is added after pooling and is broadcast over the batch
        and the two spatial dimensions (one value per filter)"""

        W, b = self.get_params()
        conved = conv.conv2d(self.input, W,
                             filter_shape=self.fshape,
                             image_shape=self.ishape)
        pooled = downsample.max_pool_2d(conved, self.pshape,
                                        ignore_border=True)
        return sigmoid(pooled + b.dimshuffle('x', 0, 'x', 'x'))

    def __str__(self):
        """in_feature_maps -> nr_of_kern (receptive_field_size (wXh) /
        pool_size(wXh)), followed by the input and output shapes"""

        (nk, ifm, fw, fh) = self.fshape
        (pw, ph) = self.pshape
        weights = "[(CONV_POOL) %d -> @%d (%dx%d) / %dx%d]" % (ifm, nk,
                                                               fw, fh,
                                                               pw, ph)
        # output feature maps: size after the convolution (in - rf + 1),
        # floor-divided by the pool size
        out_w = (self.ishape[2] - fw + 1) // pw
        out_h = (self.ishape[3] - fh + 1) // ph
        fsh = (nk, out_w, out_h)
        state = "[%d/batch  @%d (%dx%d) -> @%d (%dx%d)]" % (
            tuple(self.ishape) + fsh)
        return weights + "  " + state


class CnnModel(Model):
    """A convolutional network model, with one or more ConvPool layers
    followed by a hidden layer and a logistic regression or linear
    regression layer"""

    def __init__(self, arch, lreg_size, inshape, nout, rng, xdtype, ydtype):
        """instantiate a CnnModel instance

        :param arch: ConvPool spec given as parallel sequences
            (nkerns, receptive_fields, pool_shapes); the i-th ConvPool
            layer is built from
            (nkerns[i], (rf_row, rf_col)[i], (psh_row, psh_col)[i])
        :param int lreg_size: input size of the linear/logistic regression
            layer (i.e., output size of the hidden layer)
        :param (batch_size, fmaps, rows, cols) inshape: input shape
            (batch_size, fm, rows, cols)
        :param int nout: output dimension
        :param RandomState rng: random state
        :param np.dtype/str xdtype: data type of input
        :param str ydtype: data type of output. an int type selects a
            LogisticReg top layer, a float type a LinearRegression one
        :raises ValueError: on an empty `arch`, on a non-float `xdtype`
            or on a `ydtype` that is neither int nor float
        """

        if not arch:
            raise ValueError("empty arch. need at least one ConvPool layer")
        if xdtype not in (np.float64, np.float32, 'float64', 'float32'):
            raise ValueError("bad xdtype (%s). must be a float32/64" % str(xdtype))

        xdtype = {np.float64: "float64",
                  np.float32: "float32"}.get(xdtype, xdtype)
        self.X = T.tensor4('X', dtype=xdtype)
        self.Y = T.vector("Y", dtype=ydtype)

        # input / output shape of every layer, keyed by layer index
        self.ish = {}
        self.osh = {}

        # unpacking also checks that inshape has exactly 4 entries
        in_bs, _in_fm, _in_w, _in_h = inshape
        self.ish[0] = inshape
        layers = []
        this_input = self.X.reshape(inshape)
        for i, (nkern, rf, ps) in enumerate(zip(*arch)):
            if i:
                self.ish[i] = self.osh[i - 1]
            layers.append(ConvPoolLayer(this_input,
                                        (nkern, self.ish[i][1], rf[0], rf[1]),
                                        self.ish[i], rng, ps, xdtype))
            this_input = layers[-1].activation()
            # size after the convolution (in - rf + 1), floor-divided by
            # the pool size; `//` keeps the shapes integral
            self.osh[i] = (in_bs, nkern,
                           (self.ish[i][2] - rf[0] + 1) // ps[0],
                           (self.ish[i][3] - rf[1] + 1) // ps[1])

        n_cp = len(layers)
        if not n_cp:
            # arch was non-empty but described no layer (e.g. ([], [], []))
            raise ValueError("empty arch. need at least one ConvPool layer")

        # hidden layer on the flattened output of the top ConvPool layer
        top_osh = self.osh[n_cp - 1]
        self.ish[n_cp] = (in_bs, top_osh[1] * top_osh[2] * top_osh[3])
        self.osh[n_cp] = (in_bs, lreg_size)
        layers.append(HiddenLayer(this_input.flatten(2),
                                  self.ish[n_cp][1],
                                  self.osh[n_cp][1], rng, xdtype))

        # top layer: classification for int targets, regression for floats
        self.ish[n_cp + 1] = (in_bs, lreg_size)
        self.osh[n_cp + 1] = (in_bs, nout)
        if str(ydtype).startswith('int'):
            _topl_cls = LogisticReg
        elif str(ydtype).startswith('float'):
            _topl_cls = LinearRegression
        else:
            raise ValueError("cannot understand ydtype: %s" % str(ydtype))
        layers.append(_topl_cls(layers[-1].activation(),
                                self.ish[n_cp + 1][1], self.osh[n_cp + 1][1],
                                rng, xdtype))
        Model.__init__(self, layers)
        # cache of compiled activation functions, keyed by layer index
        self.__act_f = {}
        self.mspec = CnnModelSpec._make(tuple(arch) + (lreg_size,))

    def __compiled_activation(self, i):
        """return the compiled activation function of the i-th layer

        the function is compiled on first use only and cached afterwards
        (dict.setdefault would evaluate, i.e. compile, its default on
        every call)"""

        try:
            return self.__act_f[i]
        except KeyError:
            f = theano.function(inputs=[self[i].input],
                                outputs=self[i].activation())
            self.__act_f[i] = f
            return f

    def compute_repr(self, i, x):
        """compute the representation of x by the i-th layer.

        :param int i: the layer index [0, ..., len(self)-1]
        :param ndarray x: input. must conform
                          (i.e., reshape-able) to self.ish[i]
        :rtype: ndarray of shape self.osh[i]
        """

        f = self.__compiled_activation(i)
        return f(x.reshape(self.ish[i]))

    def compute_state(self, i, x):
        """compute the representation of `x` up to the i-th layer

        x is considered as input to the model and is propagated
        through the first i layers

        :param int i: the layer index [0, ..., len(self)-1]
        :param ndarray x: input to the model. must conform
            (i.e., reshape-able) to self.ish[0]
        :rtype: ndarray"""

        for l in range(i + 1):
            f = self.__compiled_activation(l)
            x = f(x.reshape(self.ish[l]))
        return x

    @property
    def top_cp_idx(self):
        "idx of the top convpool layer"

        # the last two layers are the hidden and the regression layer
        return len(self) - 3

    def get_speeds(self):
        """speeds of all layers

        :rtype: list with the `get_speeds()` result of every layer
        """

        return [layer.get_speeds() for layer in self]

    def set_speeds(self, vlst):
        """set speeds of all layers

        :param vlst: one speeds value per layer, in layer order
            (the format returned by `get_speeds`)"""

        for i, w in enumerate(vlst):
            self[i].set_speeds(w)

    def cost(self, l1, l2):
        """regularized cross entropy (classification) or mean squared (regression)

        :param float l1: L1 coefficient (float)
        :param float l2: L2 coefficient (float)
        :rtype: cost function
        :raises ValueError: if the top layer is neither a LogisticReg
            nor a LinearRegression"""

        l1_term = l1 * self.weight_norm("l1")
        l2_term = l2 * self.weight_norm("l2")
        toplayer = self[-1]

        if isinstance(toplayer, LogisticReg):
            # log-probability assigned to the correct class of every sample
            row_idx = T.arange(self.Y.shape[0])
            error = T.log(toplayer.p_y_given_x)[row_idx, self.Y]
            return -T.mean(error) + l1_term + l2_term
        elif isinstance(toplayer, LinearRegression):
            # activation size is <batch_size, 1>
            # self.Y size is <batch_size>, so I make this broadcastable
            # on the second dimension
            error = (toplayer.activation() - self.Y.dimshuffle(0, 'x')) ** 2
            return T.mean(error) + l1_term + l2_term
        else:
            raise ValueError("cannot compute cost for toplayer: %s" % type(toplayer))

    def update_params(self, train_batches, gradient_f, momentum, lrate):
        """step on the direction of gradient

        step on the direction of gradient
        for a whole epoch and update the model params in place.
        By definition speed is initialized to 0.
        new_speed = -rho * dE/dw + mu * speed
        new_weight = w + new_speed

        :param train_batches: indexes of batches (list)
        :param gradient_f: function that returns the list of gradients
            from the batch index. the list is ordered by layer and is
            consumed (popped from the front) by this method
        :param momentum: mu
        :param lrate: rho
        :rtype: None
        """

        for batch_i in train_batches:
            all_grads = gradient_f(batch_i)
            for layer in self:
                # take as many gradients as this layer has parameters
                n_params = len(layer.get_params())
                l_grads = [all_grads.pop(0) for _ in range(n_params)]

                layer.speed_update(l_grads, momentum, lrate)
                layer.weight_update()

    @classmethod
    def _from_archive(cls, arch, path, rng, ds, raw_data=False, batch_size=1):
        """instantiate a model and load archived weights

        the model is always built as a classifier (ydtype is 'int32')

        :param str arch: path to the HDF archive
        :param str path: path (inside the archive) of the training directory
            (usually: dataset/trid)
        :param numpy.RandomState rng: random generator
        :param ds: dataset instance; its `tracks`, `width`, `labels` and
            `X.dtype` attributes define the input shape, the output
            dimension and the data type of the model
        :param raw_data: not used by this method
        :param int batch_size: batch size of the input shape
        :rtype: :py:class:`dimer.nnet.nccn.CnnModel`
        """

        ms = CnnModelSpec._from_archive(arch, path)
        model = CnnModel(ms.cp_arch, ms.lreg_size,
                         (batch_size, 1, ds.tracks, ds.width),
                         ds.labels, rng,
                         xdtype=ds.X.dtype, ydtype='int32')
        model.load(archive.join(arch, path))
        return model


class MlpModel(Model):
    """A multi-layer perceptron: zero or more hidden layers followed by a
    logistic regression or linear regression layer"""

    def __init__(self, nin, hsizes, nout, xdtype, ydtype, rng):
        """instantiate an MlpModel instance

        :param int nin: input dimension
        :param hsizes: sizes of the hidden layers (sequence of int)
        :param int nout: output dimension
        :param xdtype: data type of input
        :param ydtype: data type of output. an int type selects a
            LogisticReg top layer, a float type a LinearRegression one
        :param RandomState rng: random state
        :raises ValueError: if `ydtype` is neither an int nor a float type
        """

        # (batch_size, size) of the input/output of every layer; the batch
        # size is fixed to 1. list() keeps the shapes indexable even where
        # zip returns an iterator
        n_layers = 1 + len(hsizes)
        self.ish = list(zip((1,) * n_layers, (nin,) + tuple(hsizes)))
        self.osh = list(zip((1,) * n_layers, tuple(hsizes) + (nout,)))

        self.X = T.matrix("X", dtype=xdtype)
        self.Y = T.vector("Y", dtype=ydtype)
        layers = []
        this_input = self.X
        # all but the last (in, out) pair describe hidden layers
        for ins, outs in zip(self.ish[:-1], self.osh[:-1]):
            layers.append(HiddenLayer(this_input, ins[1], outs[1], rng, xdtype))
            this_input = layers[-1].activation()

        # top layer: classification for int targets, regression for floats
        if str(ydtype).startswith('int'):
            _top_cls = LogisticReg
        elif str(ydtype).startswith('float'):
            _top_cls = LinearRegression
        else:
            raise ValueError("cannot understand ydtype: %s" % str(ydtype))
        layers.append(_top_cls(this_input, self.ish[-1][1], self.osh[-1][1], rng, xdtype))
        Model.__init__(self, layers)
        # cache of compiled activation functions, keyed by layer index
        self.__act_f = {}
        self.mspec = MlpModelSpec._make((hsizes,))

    def __compiled_activation(self, i):
        """return the compiled activation function of the i-th layer

        the function is compiled on first use only and cached afterwards
        (dict.setdefault would evaluate, i.e. compile, its default on
        every call)"""

        try:
            return self.__act_f[i]
        except KeyError:
            f = theano.function(inputs=[self[i].input],
                                outputs=self[i].activation())
            self.__act_f[i] = f
            return f

    def compute_repr(self, i, x):
        """compute the representation of x by the i-th layer.

        :param int i: the layer index
        :param ndarray x: input. must be reshape-able to self.ish[i]
        :rtype: ndarray"""

        f = self.__compiled_activation(i)
        return f(x.reshape(self.ish[i]))

    def compute_state(self, i, x):
        """propagate the model input `x` through layers 0, ..., i

        :param int i: the layer index
        :param ndarray x: input to the model. must be reshape-able
            to self.ish[0]
        :rtype: ndarray"""

        for l in range(i + 1):
            f = self.__compiled_activation(l)
            x = f(x.reshape(self.ish[l]))
        return x

    def get_speeds(self):
        """speeds of all layers

        :rtype: list with the `get_speeds()` result of every layer"""

        return [layer.get_speeds() for layer in self]

    def set_speed(self, vlst):
        """set speeds of all layers

        :param vlst: one speeds value per layer, in layer order
            (the format returned by `get_speeds`)"""

        for i, w in enumerate(vlst):
            self[i].set_speeds(w)

    def cost(self, l1, l2):
        """regularized cross entropy (classification) or mean squared (regression)

        :param float l1: L1 coefficient
        :param float l2: L2 coefficient
        :rtype: cost function
        :raises ValueError: if the top layer is neither a LogisticReg
            nor a LinearRegression"""

        l1_term = l1 * self.weight_norm("l1")
        l2_term = l2 * self.weight_norm("l2")
        toplayer = self[-1]

        if isinstance(toplayer, LogisticReg):
            # log-probability assigned to the correct class of every sample
            row_idx = T.arange(self.Y.shape[0])
            error = T.log(toplayer.p_y_given_x)[row_idx, self.Y]
            return -T.mean(error) + l1_term + l2_term
        elif isinstance(toplayer, LinearRegression):
            # the activation is <batch_size, 1> and Y is <batch_size>:
            # make Y broadcastable on the second dimension
            error = (toplayer.activation() - self.Y.dimshuffle(0, 'x')) ** 2
            return T.mean(error) + l1_term + l2_term
        else:
            raise ValueError("cannot compute cost for toplayer: %s" % type(toplayer))

    def update_params(self, train_batches, gradient_f, momentum, lrate):
        """step on the direction of gradient for a whole epoch

        the speed and then the weights of every layer are updated in
        place, once per batch.

        :param train_batches: indexes of batches
        :param gradient_f: function that returns the list of gradients
            from the batch index. the list is ordered by layer and is
            consumed (popped from the front) by this method
        :param momentum: momentum passed to `speed_update` of every layer
        :param lrate: learning rate passed to `speed_update` of every layer
        :rtype: None"""

        for batch_i in train_batches:
            all_grads = gradient_f(batch_i)
            for layer in self:
                # take as many gradients as this layer has parameters
                n_params = len(layer.get_params())
                l_grads = [all_grads.pop(0) for _ in range(n_params)]
                layer.speed_update(l_grads, momentum, lrate)
                layer.weight_update()
