
import unittest
import os
import datetime
import re
from itertools import product

import numpy as np
from theano.tensor.shared_randomstreams import RandomStreams

from train import UTrainer, BaseTrainer
from autoencoder import AEStack
import monitor
from dimer import data_tests, archive_tests
from dimer.nnet import config_spec, base_test_classes
from dimer.data import TrainAnnotatedDataset
from dimer import archive


class TestUTrainer(unittest.TestCase, base_test_classes.NpyTester):
    """Tests for UTrainer naming, seed round-tripping and state persistence."""

    def setUp(self):
        """Build a small unlabeled dataset, an AEStack model and RNGs.

        A fresh seed is drawn for every test and stored on ``self.seed`` so
        that tests can check it is embedded in the trainer name.
        """
        # NOTE(review): positional MtrainSpec fields; only `batch_size` and
        # `minepochs` are read in this class -- confirm order in config_spec.
        self.mtr = config_spec.MtrainSpec._make((10, 0, 0, 0.1, 1,
                                                 0.0, 10, 2, 1))
        ds = data_tests.random_adataset(shape=(100, 4, 10), labeled=False)
        self.ds = TrainAnnotatedDataset(ds.pX, ds.sY, ds.dfT,
                                        self.mtr.batch_size)
        # presumably reshapes X to 2-D so X.shape[1] is the input width
        self.ds.flatten()
        self.seed = np.random.randint(5000, 10000)
        self.rng = np.random.RandomState(self.seed)
        self.thrng = RandomStreams(self.seed)
        # presumably (10, 5, 2) are the stacked layer sizes -- confirm in AEStack
        self.model = AEStack(self.ds.X.shape[1],
                             (10, 5, 2), self.rng, self.thrng,
                             self.ds.X.dtype, 0)

    def _get_trainer(self, seed, train_strat=None):
        """Return a UTrainer over ``self.ds``/``self.model``.

        :param seed: seed handed to UTrainer
        :param train_strat: if truthy, every layer of ``self.model`` is
            trained with this strategy for ``self.mtr.minepochs`` epochs
        :returns: the (possibly trained) UTrainer instance
        """
        import random
        import string
        # random prefix of 4 distinct lowercase letters
        prefix = "".join(random.sample(string.ascii_lowercase, 4))
        train = UTrainer(self.ds, self.model, self.mtr,
                         monitor.DaeLearnMonitor, seed,
                         prefix=prefix)
        if train_strat:
            for layer_idx in range(len(self.model)):
                train.train_layer(layer_idx, train_strat, self.mtr.minepochs)
        return train

    def test_trainer_name(self):
        """str(trainer) embeds the current timestamp and the seed."""
        fmt = "%Y%m%d_%H%M"
        before = datetime.datetime.now().strftime(fmt)
        strain = str(self._get_trainer(self.seed))
        after = datetime.datetime.now().strftime(fmt)
        # The name's timestamp is produced somewhere between `before` and
        # `after`; accept either minute so the test cannot flake when a
        # minute boundary falls in between.
        stamps = "|".join(sorted(set((before, after))))
        pattern = re.compile(r"tr_...._(?:%s)\d\d_%d_" % (stamps, self.seed))
        self.assertIsNotNone(pattern.match(strain),
                             "%r does not match %r" % (strain, pattern.pattern))

    def test_trainer_name_seed(self):
        """seed_from_train_name inverts train_name for any prefix/suffix."""
        affixes = ("", "some_non_empty_word")
        for _ in range(100):
            seed = np.random.randint(1000000)
            for pref, suf in product(affixes, affixes):
                tname = BaseTrainer.train_name(pref, seed, suf)
                self.assertEqual(seed, BaseTrainer.seed_from_train_name(tname))

    @unittest.skip("todo")
    def test_train_strategies(self):
        pass

    def test_load_state(self):
        """A trained model saved to an archive reloads with equal weights."""
        train = self._get_trainer(self.seed, "min_epochs")
        mtp = type(self.model)
        with archive_tests.empty_archive() as (archp, dsn):
            ds_path = archive.join(archp, dsn)
            self.ds.dump(ds_path)
            train.save(ds_path)
            # load inside the context so the archive still exists
            ld_model = mtp._from_archive(archp,
                                         os.path.join(dsn, str(train)),
                                         self.rng, self.thrng, self.ds.X.dtype)
        self.assertEqual(len(ld_model), len(self.model))
        for ld_layer, layer in zip(ld_model, self.model):
            ldw, ldb, ldbp = ld_layer.get_weights()
            w, b, bp = layer.get_weights()
            self.assertEqualArray(w, ldw)
            self.assertEqualArray(b, ldb)
            self.assertEqualArray(bp, ldbp)


#class TestExperiment( unittest.TestCase ):
#    def setUp(self):
#        pass
#
#    def test_train_name(self):
#        "test experiment name generation and parsing"
#        cfgf = "/a/dir/a_file_a.cfg"
#        seed = 1234
#        p = TrainExperiment.parse_train_name(TrainExperiment.train_name(cfgf, seed))
#        self.assertEqual(p[0], os.path.splitext(os.path.basename(cfgf))[0])
#        self.assertEqual(p[2], str(seed))
#
#    @unittest.skip("todo")
#    def test_load_state(self):
#        "test that 0-layer state is equal to X"
#
#        pass
#
#class TestTrainExperiment( unittest.TestCase ):
#    def setUp(self):
#        pass
#
#    @unittest.skip("todo")
#    def test_init(self):
#        "test attributes are properly initialized"
#        pass
#
#    @unittest.skip("todo")
#    def test_epoch(self):
#        "test epoch"
#        pass
#
#    @unittest.skip("todo")
#    def test_isup(self):
#        "test is_up"
#        pass
#
#    @unittest.skip("todo")
#    def test_restore(self):
#        "test that params of epoch-patience are restored"
#        pass
#
#    @unittest.skip("todo")
#    def test_save(self):
#        "test that model is properly saved"
#        pass
#
#    @unittest.skip("todo")
#    def test_func(self):
#        "test memoization of functions"
#        pass
