"""utilities for importing and creating notable datasets"""

import cPickle
import os
import numpy as np
import pandas as pd
import dimer
from dimer.nnet.config_spec import ExpSpec


class Mnist(object):
    """Helpers to load the pickled MNIST digits and dump them as dimer
    AnnotatedDataset archives."""

    # download location of the (gzipped) pickled dataset
    url = "http://deeplearning.net/data/mnist/mnist.pkl.gz"

    @classmethod
    def load_from_pkl(cls, pkl_file):
        """load images and labels from the pickled MNIST dataset

        The pickle is indexed as three ``(x, y)`` splits. Only splits 1 and 2
        (valid and test, in that order) are concatenated and returned; split 0
        (train) is not used.

        :param str pkl_file: pickled unzipped file as in Mnist.url
        :raises IOError: if ``pkl_file`` is not an existing file
        :rtype: a pair of ndarrays (X, Y)
        """

        if not os.path.isfile(pkl_file):
            raise IOError("%s not found. download from %s" % (pkl_file,
                                                              cls.url))
        # pickles are byte streams: open in binary mode, and let `with`
        # close the handle instead of leaking it
        with open(pkl_file, "rb") as fd:
            alldt = cPickle.load(fd)
        # alldt is indexed as train(x, y), valid(x, y), test(x, y)
        # NOTE(review): the train split (alldt[0]) is skipped -- confirm intended
        x = np.concatenate((alldt[1][0], alldt[2][0]), axis=0)
        y = np.concatenate((alldt[1][1], alldt[2][1]), axis=0)
        return (x, y)

    @staticmethod
    def _select_digit_pair(x, y, l1, l2):
        """keep only the images whose label is l1 or l2

        :param ndarray x: images (first axis indexes images)
        :param ndarray y: digit of each image
        :param int l1: first digit
        :param int l2: second digit
        :rtype: triple (x_sel, y_sel, codes); codes is an int32 array that
            is 1 where the label equals max(l1, l2) and 0 elsewhere"""

        keep = np.logical_or(y == l1, y == l2)
        x_sel, y_sel = x[keep], y[keep]
        # binary class code: 0 for the smaller digit, 1 for the larger one.
        # A comparison is used instead of integer division, which depended on
        # Python 2 floor division and failed when both digits were 0.
        codes = (y_sel == max(l1, l2)).astype('int32')
        return x_sel, y_sel, codes

    @staticmethod
    def _label_annotations(y, codes):
        """build the label Series and the label table for a dataset

        :param ndarray y: digit of each image
        :param ndarray codes: int32 class code of each image
        :rtype: pair (pd.Series of y, pd.DataFrame with columns
            "label_code" and "label_name")"""

        Y = pd.Series(y)
        T = pd.DataFrame({"label_code": codes,
                          "label_name": [str(v) for v in y]})
        return Y, T

    @classmethod
    def _dump_all(cls, x, y, to_file, dsn):
        """dump images and labels into archive

        :param ndarray x: images
        :param ndarray y: digits
        :param str to_file: archive file
        :param str dsn: dataset name
        :rtype: None"""

        experiment = ExpSpec(to_file, dsn, "")
        # one 28-row frame per image (columns inferred from x)
        # NOTE(review): pd.Panel was removed in pandas 0.25; requires old pandas
        X = pd.Panel(x.reshape((x.shape[0], 28, -1)))
        # Y must be a Series -- the original trailing comma made it a tuple
        Y, T = cls._label_annotations(y, np.asarray(y, dtype='int32'))
        dt = dimer.data.AnnotatedDataset(X, Y, T)
        dt.dump(experiment.ds_path)

    @classmethod
    def _dump_digit_pair(cls, x, y, l1, l2, to_file, dsn):
        """dump images and labels from a pair of digits into archive

        Only images labelled l1 or l2 are kept; their label_code is 0 for the
        smaller digit and 1 for the larger one.

        :param ndarray x: images
        :param ndarray y: digits
        :param int l1: first digit
        :param int l2: second digit
        :param str to_file: archive file
        :param str dsn: dataset name
        :rtype: None"""

        x_sel, y_sel, codes = cls._select_digit_pair(x, y, l1, l2)
        X = pd.Panel(x_sel.reshape((-1, 28, 28)))
        Y, T = cls._label_annotations(y_sel, codes)
        dtbin = dimer.data.AnnotatedDataset(X, Y, T)
        # sanity check that the stored labels span exactly the two digits
        # (presumably dtbin.Y is the label Series passed in -- confirm)
        assert (int(dtbin.Y.min()), int(dtbin.Y.max())) == (min(l1, l2), max(l1, l2))
        experiment = ExpSpec(to_file, dsn, "")
        dtbin.dump(experiment.ds_path)
