
import unittest
import tempfile

from data import Dataset, AnchorDataset, TrainAnchorDataset
import numpy as np
import pandas as pd
from archive_tests import empty_archive
from archive import join

# Module-level random generator seeded with 10.  The helper functions and
# test cases in this file all draw from this single shared instance, so the
# values a given test sees depend on how many draws happened before it.
rng = np.random.RandomState(10)


def random_dataset(shape=(100, 80)):
    """Build a Dataset filled with values drawn from the module-level rng.

    shape -- 2-tuple giving the shape of the random matrix.

    Dataset receives, in order: a random matrix of the given shape, a random
    vector of length shape[0], and an int32 vector of length shape[0] whose
    first half is 0 and whose remainder is 1.
    """
    n_rows = shape[0]
    # Floor division keeps the count an integer even when `/` means true
    # division (Python 3 or `from __future__ import division`); a float is
    # not accepted by numpy as an array dimension.
    n_zeros = n_rows // 2
    halves = np.concatenate([np.zeros((n_zeros,), dtype=np.int32),
                             np.ones((n_rows - n_zeros,), dtype=np.int32)])
    # Argument order matters: the matrix is drawn from rng before the vector.
    return Dataset(rng.rand(*shape), rng.rand(n_rows), halves)


def random_adataset(shape=(100, 8, 10), labeled=True):
    """Build a random AnchorDataset for use in tests.

    shape   -- 3-tuple.  The panel gets shape[0] items named "gene%d",
               shape[1] major-axis entries named "track%d" and shape[2]
               minor-axis entries named "pos%d".
    labeled -- when true the label frame is handed to AnchorDataset;
               otherwise None is passed in its place.
    """
    ds = random_dataset((shape[0], shape[1] * shape[2]))
    X = ds.X.reshape(shape)

    # Concrete lists rather than map() results: gnames is consumed twice
    # below, which a lazy iterator would not survive.
    gnames = ["gene%d" % i for i in range(X.shape[0])]
    tracks = ["track%d" % i for i in range(X.shape[1])]
    width = ["pos%d" % i for i in range(X.shape[2])]
    labels = ("R", "I")

    pX = pd.Panel(X, items=gnames, major_axis=tracks, minor_axis=width)
    # ds.T holds 0/1 codes, used here as indices into `labels`.
    dfT = pd.DataFrame({"label_code": ds.T,
                        "label_name": [labels[code] for code in ds.T]})
    sY = pd.Series(ds.Y, index=gnames)

    return AnchorDataset(pX, sY, dfT if labeled else None)


def archive_ds(ds=None):
    """Dump a dataset into an archive from empty_archive and return its location.

    ds -- object exposing dump(); a fresh random_dataset() is used when None.

    The returned value is join() applied to the tuple yielded by
    empty_archive(tmp=False), i.e. the same location the dataset was
    dumped to.
    """
    dataset = random_dataset() if ds is None else ds
    with empty_archive(tmp=False) as location:
        target = join(*location)
        dataset.dump(target)
    # join is re-evaluated after the context has exited, as the location
    # tuple stays bound outside the with-block.
    return join(*location)


def archive_ads(ads=None):
    """Archive an anchor dataset via archive_ds and return the location.

    ads -- dataset to archive; a fresh random_adataset() is used when None.
    """
    if ads is None:
        ads = random_adataset()
    return archive_ds(ads)


class TestDataset(unittest.TestCase):
    """Tests for Dataset feature scaling helpers and the is_labeled flag."""

    def setUp(self):
        # Fresh random matrix for every test; test_normalize_feats
        # overwrites one of its columns in place.
        self.X = rng.rand(100, 800)

    def test_normalize_feats(self):
        "0-mean and 1-variance features"

        def check_norm(x, m, v):
            # x is the normalized matrix; m and v are the two per-feature
            # vectors returned alongside it by normalize_features.
            self.assertEqual(m.shape, v.shape)
            self.assertEqual(m.shape, tuple(x.shape[1:]))
            # Diagnostic output; parenthesised so that it is valid both as a
            # print statement and as a print() call.
            print(x.reshape(x.shape[0], -1))
            print(m.reshape((-1,)))
            print(v.reshape((-1,)))

            for f in range(x.shape[1]):
                self.assertAlmostEqual(x[:, f].mean(), 0)
                if v[f] != 0:
                    self.assertAlmostEqual(x[:, f].std(), 1)
                else:
                    # A feature reported with zero spread must stay constant.
                    self.assertAlmostEqual(x[:, f].std(), 0)

        X = self.X
        check_norm(*Dataset.normalize_features(X))
        # Make the first feature constant to exercise the zero-spread branch.
        X[:, 0] = 0
        print(X[:, 0])
        check_norm(*Dataset.normalize_features(X))

    def test_fit_feats(self):
        "features in [-1,1]"

        print(self.X)
        fx = Dataset.fit_features(self.X)
        self.assertAlmostEqual(fx.min(), -1)
        self.assertAlmostEqual(fx.max(), 1)
        self.assertEqual(fx.shape, self.X.shape)

    def test_labeled(self):
        "test if it is labeled"
        x = rng.rand(3, 4)
        # Third constructor argument present -> labeled; None -> unlabeled.
        self.assertTrue(Dataset(x, rng.rand(3), rng.rand(3)).is_labeled)
        self.assertFalse(Dataset(x, rng.rand(3), None).is_labeled)


class TestAnchorDataset(unittest.TestCase):
    """Tests for AnchorDataset and TrainAnchorDataset.

    Covers annotations, the shared copy of X, batch allocation and
    iteration, archive round-trips, and (on a plain Dataset) the
    information measures.
    """

    def setUp(self):
        self.labds = random_adataset()
        self.ds = random_adataset(labeled=False)
        # Random batch size; floor division keeps the upper bound an integer
        # even when `/` means true division.
        self.bs = rng.randint(2, self.ds.X.shape[0] // 4)

    def test_annotations(self):
        "test annotations"
        self.assertEqual(set(self.labds.label_names),
                         set(self.labds.dfT["label_name"].tolist()))
        self.assertEqual(self.labds.track_names,
                         self.labds.pX.major_axis.tolist())

        # The unlabeled dataset has no label names but the same track names.
        self.assertEqual(self.ds.label_names, None)
        self.assertEqual(self.ds.track_names,
                         self.labds.pX.major_axis.tolist())

    def test_share(self):
        "dataset on theano shared vars"

        self.assertTrue(np.all(self.ds.shX.get_value() == self.ds.X))

    def test_batch_allocation(self):
        "test batch allocation"
        ds = TrainAnchorDataset(self.labds.pX, self.labds.sY, self.labds.dfT,
                                self.bs)

        ## by default, valid batches are at the end of the data
        ## they should also be in order (not shuffled by default)
        ## because rng=None
        all_batches = list(range(ds.n_batches))
        self.assertEqual(ds.train_batches + ds.valid_batches, all_batches)

        ## enable shuffling
        ds = TrainAnchorDataset(self.labds.pX, self.labds.sY, self.labds.dfT,
                                self.bs, rng=rng)
        all_batches = list(range(ds.n_batches))
        # NOTE(review): this can fail by chance if the shuffle happens to
        # return the identity order -- confirm that is acceptable.
        self.assertNotEqual(ds.train_batches + ds.valid_batches, all_batches)
        self.assertEqual(set(ds.train_batches + ds.valid_batches),
                         set(all_batches))

    def test_batch_iter(self):
        "test batch iteration"
        ds = TrainAnchorDataset(self.labds.pX, self.labds.sY, self.labds.dfT,
                                self.bs, rng=rng)

        # One pass yields the batch list once ...
        self.assertEqual(5 * list(ds.iter_train(1)), 5 * ds.train_batches)
        self.assertEqual(7 * list(ds.iter_valid(1)), 7 * ds.valid_batches)

        # ... and k passes yield it k times over.
        self.assertEqual(list(ds.iter_train(5)), 5 * ds.train_batches)
        self.assertEqual(list(ds.iter_valid(7)), 7 * ds.valid_batches)

    def test_aio(self):
        "test anchordataset IO"
        lds = self.labds

        # Loaded with True as second argument: X must match the original.
        ods = AnchorDataset._from_archive(archive_ads(lds), True)
        self.assertEqual(ods.label_names, lds.label_names)
        self.assertEqual(ods.track_names, lds.track_names)
        self.assertAlmostEqual(np.max(np.abs(ods.X - lds.X)), 0)
        self.assertTrue(np.all(ods.sY == lds.sY))
        print(ods.dfT)
        print(lds.dfT)
        self.assertTrue(np.all(ods.dfT == lds.dfT))

        # Loaded with False: X must match the normalized original features.
        ods = AnchorDataset._from_archive(archive_ads(lds), False)
        ldsX, m, sd = lds.normalize_features(lds.X.reshape(lds.X.shape[0], -1))
        self.assertAlmostEqual(
            np.max(np.abs(ods.X - ldsX.reshape(ods.X.shape))), 0)

    def test_tio(self):
        "test trainanchordataset IO"
        lds = TrainAnchorDataset(self.labds.pX, self.labds.sY,
                                 self.labds.dfT, self.bs)
        ods = TrainAnchorDataset._from_archive(archive_ds(lds), False, self.bs)
        self.assertEqual(ods.train_batches, lds.train_batches)
        self.assertEqual(ods.valid_batches, lds.valid_batches)

    def test_information(self):
        "information"

        # 24 rows of 6 integers: column 0 is a row number, columns 1-4 are
        # used as features and column 5 as the target.
        # (Appears to be the UCI "lenses" data set -- TODO confirm.)
        lenses_rows = """1  1  1  1  1  3
                         2  1  1  1  2  2
                         3  1  1  2  1  3
                         4  1  1  2  2  1
                         5  1  2  1  1  3
                         6  1  2  1  2  2
                         7  1  2  2  1  3
                         8  1  2  2  2  1
                         9  2  1  1  1  3
                         10 2  1  1  2  2
                         11 2  1  2  1  3
                         12 2  1  2  2  1
                         13 2  2  1  1  3
                         14 2  2  1  2  2
                         15 2  2  2  1  3
                         16 2  2  2  2  3
                         17 3  1  1  1  3
                         18 3  1  1  2  3
                         19 3  1  2  1  3
                         20 3  1  2  2  1
                         21 3  2  1  1  3
                         22 3  2  1  2  2
                         23 3  2  2  1  3
                         24 3  2  2  2  3""".split("\n")
        lenses = np.array([[int(tok) for tok in row.split()]
                           for row in lenses_rows])
        ds = Dataset(lenses[:, 1:5], lenses[:, 5], lenses[:, 5])

        # Floating-point results: compare approximately rather than bit for
        # bit, so harmless rounding differences do not fail the test.
        self.assertAlmostEqual(ds.target_information(), 1.3260875253642983)

        self.assertAlmostEqual(ds.attr_entropy(0, None), 1.2866910217181771)
