'''
Support for dataset IO and manipulation
'''

import logging
from collections import Counter
import numpy as np
import pandas as pd
import theano

from dimer import archive
from dimer import ops


log = logging.getLogger(__name__)


class Dataset(object):
    """a bare-bones dataset with predictors (X),
    and quantitative (Y) and qualitative (T) outputs"""

    def __init__(self, X, Y, T):
        """instantiate the class

        :param ndarray X: predictors <nr of samples X data dimension> or
                        <nr of samples X data dimension X data dimension>
        :param ndarray Y: output <nr of samples X 1> or one dimensional;
                        None if missing
        :param ndarray T: class output <nr of samples X 1 > or one
                        dimensional; None if missing
        :raises ValueError: if Y or T has a different number of
                        samples than X"""

        self.X, self.Y, self.T = X, Y, T

        if Y is not None and X.shape[0] != Y.shape[0]:
            raise ValueError("|X| (%d) != |Y| (%d)" % (X.shape[0],
                                                       Y.shape[0]))
        if T is not None and X.shape[0] != T.shape[0]:
            raise ValueError("|X| (%d) != |T| (%d)" % (X.shape[0],
                                                       T.shape[0]))
        log.info("allocated dataset. X of shape %s, Y %s, T %s",
                 str(self.X.shape),
                 "missing" if self.Y is None else str(self.Y.shape),
                 "missing" if self.T is None else str(self.T.shape))

    def target_information(self, T=None):
        """compute the information (entropy, in bits) of the targets

        the formula is -sum_i (si / s) * log2(si / s), where
        si is the count of examples in class i (i.e., with target value = i),
        s is the total number of examples and i runs over targets

        :param T: targets to measure; defaults to the targets of this dataset
        :rtype: a float indicating the information of targets
        :raises ValueError: if no targets are given and the dataset has none"""

        if T is None:
            T = self.T
        if T is None:
            raise ValueError("targets missing for this dataset")

        # ravel first: targets may be column shaped <nr of samples X 1>, and
        # rows of a 2D array are not hashable by Counter
        counts = Counter(np.ravel(T).tolist())
        total = float(sum(counts.values()))
        info = 0.0
        for cnt in counts.values():
            ratio = cnt / total
            info -= ratio * np.log2(ratio)
        return info

    def attr_entropy(self, attr_idx, nbins):
        """compute the entropy of the targets conditioned on the value of
        the attribute with the given index: the target information of each
        group of samples sharing an attribute value, weighted by group size

        :param int attr_idx: index of the attribute to consider. If a sample is
                  2D, the index refers to the flattened sample.
        :param int nbins: the number of bins to discretize the attribute
                  (only 2 is supported); None to use the raw values
        :rtype: a float, the weighted target information
        :raises ValueError: if targets are missing or nbins is not 2 or None"""

        def discretize_x(X):
            if nbins != 2:
                raise ValueError("only 2 bins supported so far")
            low, high = np.min(X), np.max(X)
            # two bins split at 0; low and high only close the bin edges
            return np.digitize(X, [low, 0, high])

        T = self.T
        if T is None:
            raise ValueError("targets missing for this dataset")

        if len(self.X.shape) == 2:
            X = self.X[:, attr_idx]
        else:
            # map the flat index to (row, col); divmod keeps both integers
            # (true division would give a float index on Python 3)
            q, r = divmod(attr_idx, self.X.shape[-1])
            X = self.X[:, q, r]
        if nbins is not None:
            X = discretize_x(X)

        entropy = 0.0
        s = float(T.shape[0])
        for a in np.unique(X):
            subset = T[X == a]
            entropy += (subset.shape[0] / s) * self.target_information(subset)
        return entropy

    def flatten(self):
        """flatten X of this dataset (in place).

        if samples are 2d and you are not using a bottom layer that
        handles such an input, you need to flatten it before usage.
        """

        self.X = self.X.reshape((self.X.shape[0], -1))

    @property
    def is_labeled(self):
        """True if this dataset has class targets (T)"""
        return self.T is not None

    @property
    def labels(self):
        """number of distinct values in T

        :raises AttributeError: if the dataset is unlabeled"""
        if self.is_labeled:
            return np.unique(self.T).shape[0]
        else:
            raise AttributeError("unlabeled dataset")

    @staticmethod
    def normalize_features(x):
        """transform each component of flattened X examples to 0 mean and 1 std
        So the values of feature f (from all examples) are 0 mean and 1 std

        Delegates to ops.standardize(x, axis=0).

        :param numpy.ndarray x: a ndarray of shape
                                (nr. examples, nr. of features)
        :rtype: (the shifted input, the mean for each input
                component, the sd of each
                input component)
        """

        return ops.standardize(x, axis=0)

    @staticmethod
    def fit_features(x):
        """transform each **component** of X so
        that it fits on an interval [-1, 1].
        So the values of track t at position i are all in [-1,1]

        Delegates to ops.fit(x, axis=0).

        :param x: a ndarray of shape (nr. examples, nr. features)
        :rtype: the fitted input
        """

        return ops.fit(x, axis=0)

#    @staticmethod
#    def state_path(i, dir=None):
#        """path to a state
#
#        :param str dir: path to the training instance. (e.g., data/train)
#                    if None a relative path will be returned
#        :param int i: index of the layer
#        :rtype: str"""
#
#        p = os.path.join("layer_%d" % i, "state")
#        if not (dir is None):
#            p = os.path.join(dir, p)
#        return p


class TheanoShare(object):
    """mixin that hands out the members of a dataset (X, Y, T) wrapped
    in theano shared variables"""

    def __init__(self):
        # one lazily filled cache slot per shareable member
        self.__shX = None
        self.__shY = None
        self.__shT = None

    def __sh_anon(self, what, shape=None, borrow=True):
        """return the cached shared variable of member `what`,
        creating (and caching) it on first use

        :param str what: name of the member to share ('X', 'Y' or 'T')
        :param shape: if truthy, reshape the data to this shape before sharing
        :param borrow: passed to theano.shared
        :raises ValueError: if the member is None"""

        data = getattr(self, what)
        if data is None:
            raise ValueError("cannot share non-existent member %s" % what)

        slot = "_TheanoShare__sh%s" % what
        cached = getattr(self, slot)
        if cached is not None:
            return cached

        if shape:
            data = data.reshape(shape)
        cached = theano.shared(data, borrow=borrow)
        setattr(self, slot, cached)
        return cached

    @property
    def shX(self):
        """shared variable wrapping X (created once, then reused)"""
        return self.__sh_anon("X")

    @property
    def shY(self):
        """shared variable wrapping Y (created once, then reused)"""
        return self.__sh_anon("Y")

    @property
    def shT(self):
        """shared variable wrapping T (created once, then reused)"""
        return self.__sh_anon("T")

    def share(self, which, shape=None, borrow=True):
        """wrap the data in a new theano.shared variable (not cached)

        :param which: what component to wrap (str, typically 'X', 'T', 'Y')
        :param shape: reshape the array to this shape (None keeps the shape)
        :param borrow: passed to theano.shared
        :rtype: theano.shared instance initialized to the required data"""

        data = getattr(self, which)
        if shape is not None:
            data = data.reshape(shape)
        return theano.shared(data, borrow=borrow)


class TrainDataset(object):
    """a mixin for batch functionality, valid and train sub-dataset

    the host class must provide self.X; only X.shape[0] is read here"""

    def __init__(self, batch_s, tot_s=None, valid_s=None, valid_idx=None, rng=None):
        """Dataset that will
        create train and validation batches from the given params.

        the idea is to split the data into batches and allocate a 'valid_s'
        portion of them for validation. the position of the (continuous) validation
        block is w.r.t batches. E.g., for tot_size = 10, batch_s = 2,
        valid_idx=3, valid_s = 0.3 you get 4 + 1 train + valid batches: T T T V T

        The resulting lists of batch indices are stored in
        self.train_batches and self.valid_batches.

        :param batch_s : batch size
        :param tot_s: nr. of examples (default, and at most, X.shape[0])
        :param valid_s : fraction of data to allocate for validation
                         (default 0.25)
        :param valid_idx: batch index at which allocate validation data
                          (default: the last batches)
        :param rng : numpy.RandomState used to shuffle batches or None (no shuffle)
        :raises ValueError: if valid_s is outside (0, 1) or batch_s is larger
                            than the train or the validation part"""

        if tot_s is None:
            tot_s = self.X.shape[0]
        if self.X.shape[0] < tot_s:
            log.warning("total size (%d) > dataset size (%d). adjusting ...",
                        tot_s, self.X.shape[0])
        self.total_size = min(tot_s, self.X.shape[0])

        self.batch_size = batch_s
        if valid_s is None:
            valid_s = 0.25
        self.valid_size = valid_s

        # floor division: a batch count must be an int (true division would
        # make it, and the default valid_idx below, a float on Python 3)
        # NOTE(review): this counts batches over the whole X while the split
        # below uses total_size; the two differ when tot_s < X.shape[0] --
        # confirm which one is intended
        self.n_batches = self.X.shape[0] // self.batch_size  # nr. of batches
        if valid_idx is None:
            valid_idx = self.n_batches - int(self.n_batches * self.valid_size)
        self.valid_idx = valid_idx

        self.rng = rng

        self.train_batches, self.valid_batches = self.__batches()

    def __batches(self):
        """split the batch indices into train and validation batches

        :rtype: (train_batches, valid_batches), two lists of batch indices
        :raises ValueError: if valid_s is outside (0, 1) or batch_s is larger
                            than the train or the validation part"""

        tot_size = self.total_size
        batch_s = self.batch_size
        valid_s = self.valid_size
        valid_idx = self.valid_idx
        rng = self.rng

        if valid_s <= 0 or valid_s >= 1:
            raise ValueError("valid_s (%f) should be between (0, 1) " % valid_s)

        if batch_s > tot_size * min(valid_s, 1 - valid_s):
            raise ValueError("batch_s (%d) > min(valid_s=%d, train_s=%d)" % (batch_s,
                             tot_size * valid_s, tot_size * (1 - valid_s)))

        # a real list, so that it can be shuffled in place and sliced
        all_batches = list(range(tot_size // batch_s))
        if rng is not None:
            rng.shuffle(all_batches)

        # slicing never raises: an out-of-range valid_idx just selects fewer
        # (or no) batches, so report that case instead of failing silently
        n_valid = int(len(all_batches) * valid_s)
        valid_batches = all_batches[valid_idx: valid_idx + n_valid]
        if not valid_batches:
            log.warning("no validation batches selected "
                        "(valid_idx=%s, %d batches in total)",
                        str(valid_idx), len(all_batches))

        train_batches = list(set(all_batches) - set(valid_batches))
        assert set(train_batches + valid_batches) == set(all_batches)
        assert len(set(train_batches) & set(valid_batches)) == 0

        if rng is not None:
            rng.shuffle(train_batches)
            rng.shuffle(valid_batches)

        log.info("train batches: %s", str(train_batches))
        log.info("valid batches: %s", str(valid_batches))

        return (train_batches, valid_batches)

    def __iter_batches(self, which, nepochs):
        """loop over train/valid batches for a number of epochs

        :param which: 'train_batches' or 'valid_batches'
        :param nepochs: loop this many times over the batches
                        (0 will loop forever)
        :rtype: iterator over batch indices"""

        assert which in ("train_batches", "valid_batches")

        batches = getattr(self, which)
        if not batches:
            # nothing to yield; also avoids spinning forever when nepochs is 0
            return
        epoch = 0
        while True:
            for i in batches:
                yield i
            epoch = epoch + 1
            if epoch == nepochs:
                break

    def iter_train(self, nepochs):
        """iterate nepochs times over the train batches (0: forever)"""
        return self.__iter_batches("train_batches", nepochs)

    def iter_valid(self, nepochs):
        """iterate nepochs times over the validation batches (0: forever)"""
        return self.__iter_batches("valid_batches", nepochs)


class AnnotatedDataset(Dataset, TheanoShare):
    """dataset of 2d samples, built from pandas containers

    the raw arrays are exposed through Dataset (X, Y, T) while the
    original containers are kept as pX, sY and dfT"""

    def __init__(self, X, Y, T):
        """
        :param X: predictors; X.values is used as the data array
        :param Y: outputs (Y.values is used) or None
        :param T: class outputs, indexable by "label_code" and
                  "label_name", or None"""

        valY = None if Y is None else Y.values
        valT = None if T is None else T["label_code"].values

        Dataset.__init__(self, X.values, valY, valT)
        TheanoShare.__init__(self)
        self.pX, self.sY, self.dfT = X, Y, T

        self.label_names = None
        if self.T is not None:
            self.label_names = np.unique(self.dfT["label_name"].values).tolist()

    def dump(self, path):
        """save this dataset to an archive

        under the dataset key this stores rawX (X as given), X (normalized
        with normalize_features), meanX and sdX, plus Y and T if present

        :param str path: archive path, split by archive.split into
                         (archive, key)"""

        arch, key = archive.split(path)

        X, Y, T = self.pX, self.sY, self.dfT

        if X is not None:
            # the shape is read inside the guard; reading it first would
            # fail on a missing X before the guard could skip it
            (nsamp, ntrack, width) = X.values.shape

            archive.save_object(arch, "%s/rawX" % key, X)
            normX, meanX, sdX = self.normalize_features(X.values.reshape(nsamp, -1))

            archive.save_object(arch, "%s/X" % key,
                                pd.Panel(normX.reshape(nsamp, ntrack, width),
                                         items=X.items,
                                         major_axis=X.major_axis,
                                         minor_axis=X.minor_axis))
            archive.save_object(arch, "%s/meanX" % key,
                                pd.DataFrame(meanX.reshape(ntrack, width),
                                             index=X.major_axis,
                                             columns=X.minor_axis))
            archive.save_object(arch, "%s/sdX" % key,
                                pd.DataFrame(sdX.reshape(ntrack, width),
                                             index=X.major_axis,
                                             columns=X.minor_axis))

        if Y is not None:
            archive.save_object(arch, "%s/Y" % key, Y)

        if T is not None:
            archive.save_object(arch, "%s/T" % key, T)

    @classmethod
    def mean_sdX(cls, path):
        """load the meanX and sdX members of an archived dataset

        :param str path: archive path of the dataset
        :rtype: (meanX, sdX)"""

        ap, did = archive.split(path)

        meanX = archive.load_object(ap, "/".join((did, "meanX")))
        sdX = archive.load_object(ap, "/".join((did, "sdX")))
        return meanX, sdX

    @classmethod
    def _from_archive(cls, path, raw, **kwargs):
        """load a dataset from an archive

        :param str path: archive path of the dataset
        :param raw: if true load the "rawX" member, otherwise "X"
        :param kwargs: forwarded to the constructor
        :rtype: an instance of cls; Y and T are None if they cannot be loaded"""

        ap, did = archive.split(path)

        key = "%s/%s" % (did, "rawX" if raw else "X")
        X = archive.load_object(ap, key)

        def load_none(k, p=ap, did=did):
            try:
                return archive.load_object(p, "%s/%s" % (did, k))
            except Exception:
                # TODO: narrow this once it is known what archive raises for
                # a missing member; until then leave a trace of the failure
                log.debug("could not load %s/%s from %s, using None",
                          did, k, p, exc_info=True)
                return None
        Y, T = load_none("Y"), load_none("T")
        return cls(X, Y, T, **kwargs)


class AnchorDataset(AnnotatedDataset):
    """annotated dataset with accessors for the names and the
    dimensions of its samples"""

    def __init__(self, X, Y, T):
        """X, Y, T are a Panel, Series, and DataFrame resp."""

        AnnotatedDataset.__init__(self, X, Y, T)

    @property
    def track_names(self):
        """labels of the major axis of the predictors, as a list"""
        major = self.pX.major_axis
        return major.tolist()

    @property
    def tracks(self):
        """size of X along axis 1"""
        dims = self.X.shape
        return dims[1]

    @property
    def width(self):
        """size of X along axis 2"""
        dims = self.X.shape
        return dims[2]


class TrainAnchorDataset(AnchorDataset, TrainDataset):
    """an AnchorDataset split into train and validation batches"""

    def __init__(self, X, Y, T, batch_s,
                 tot_s=None, valid_s=None, valid_idx=None, rng=None):
        """
        :param X, Y, T: data, as for AnchorDataset
        :param batch_s, tot_s, valid_s, valid_idx, rng: batching
                 parameters, as for TrainDataset"""

        AnchorDataset.__init__(self, X, Y, T)
        TrainDataset.__init__(self, batch_s, tot_s, valid_s, valid_idx, rng)

    @classmethod
    def _from_archive(cls, path, raw, batch_s, **kwargs):
        """load a dataset from an archive

        :param str path: archive path of the dataset
        :param raw: if true load the "rawX" member, otherwise "X"
        :param batch_s: batch size for the dataset
        :param kwargs: forwarded to the constructor
        :rtype: an instance of cls"""

        # the inherited loader reads X, Y and T the same way and forwards
        # keyword arguments to cls(...); reuse it instead of a private copy
        return super(TrainAnchorDataset, cls)._from_archive(
            path, raw, batch_s=batch_s, **kwargs)


class TrainAnnotatedDataset(AnnotatedDataset, TrainDataset):
    """an AnnotatedDataset split into train and validation batches"""

    def __init__(self, X, Y, T, batch_s,
                 tot_s=None, valid_s=None, valid_idx=None, rng=None):
        """
        :param X, Y, T: data, as for AnnotatedDataset
        :param batch_s, tot_s, valid_s, valid_idx, rng: batching
                 parameters, as for TrainDataset"""

        AnnotatedDataset.__init__(self, X, Y, T)
        TrainDataset.__init__(self, batch_s, tot_s, valid_s, valid_idx, rng)

    @classmethod
    def _from_archive(cls, path, raw, batch_s, **kwargs):
        """load a dataset from an archive

        :param str path: archive path of the dataset
        :param raw: if true load the "rawX" member, otherwise "X"
        :param batch_s: batch size for the dataset
        :param kwargs: forwarded to the constructor
        :rtype: an instance of cls"""

        # the inherited loader reads X, Y and T the same way and forwards
        # keyword arguments to cls(...); reuse it instead of a private copy
        return super(TrainAnnotatedDataset, cls)._from_archive(
            path, raw, batch_s=batch_s, **kwargs)

    @classmethod
    def random(cls, N, p, p1, batch_s, seed=None):
        """create a random dataset of N examples each of shape (p, p1)

        X and Y are uniform in [0, 1); the label of an example is Y rounded
        to 0 or 1, named "A" or "B" respectively. Note that this seeds the
        global numpy random number generator.

        NOTE(review): relies on pd.Panel, which recent pandas releases no
        longer provide -- confirm the pandas version this targets.

        :param int N: number of examples
        :param int p: number of rows of each example
        :param int p1: number of cols of each example
        :param int batch_s: batch size for the dataset
        :param int seed: seed to the numpy random number generator
        :rtype: an instance of TrainAnnotatedDataset
        """

        np.random.seed(seed)

        pX = pd.Panel(np.random.rand(N, p, p1))
        sY = pd.Series(np.random.rand(N))
        # builtin int: np.int was only an alias for it and is gone from
        # current NumPy releases
        t = np.array(sY.values + 0.5, dtype=int)
        # a real list: map() is a lazy iterator on Python 3
        dfT = pd.DataFrame({"label_code": t,
                            "label_name": [("A", "B")[v] for v in t]})
        return TrainAnnotatedDataset(pX, sY, dfT, batch_s=batch_s,
                                     rng=np.random.RandomState(seed))

from _dataset_utils import Mnist
