

import sys, os
import unittest
import tempfile
from operator import itemgetter
from functools import partial
import numpy as np

from monitor import LearnMonitor, WeightMonitor
import nccn
from .. import data_tests
from .. import data
from . import ProblemType as PT


# Shared random generator handed to the CnnModel built in the weight-monitor
# test.  No seed is supplied, so numpy seeds it from OS entropy and every run
# draws different values.
rng = np.random.RandomState(seed=None)


#@unittest.skip("for now")
class TestMonitor(unittest.TestCase):
    """Unit tests for the ``LearnMonitor`` and ``WeightMonitor`` records."""

    @staticmethod
    def _fake_model(lr=0.1):
        """Return a throw-away class whose only attribute is ``lr``.

        ``LearnMonitor._from_fs`` receives this in place of a real model; the
        tests only check that ``lr`` is reported back as ``lrate``.
        """
        return type("FakeModel", (object,), dict(lr=lr))

    def setUp(self):
        # Six eight-field LearnMonitor records in which only the third field
        # varies: 0, 1, 2, 1, 2, 3 (presumably the ``traincost`` field queried
        # below -- TODO confirm field order).  Materialised as a list rather
        # than a bare ``map`` so the slicing in the tests also works on
        # Python 3, where ``map`` returns a lazy iterator.
        rows = (
            (0, 0, 0, 0, 0, 0, 0, 0),
            (0, 0, 1, 0, 0, 0, 0, 0),
            (0, 0, 2, 0, 0, 0, 0, 0),
            (0, 0, 1, 0, 0, 0, 0, 0),
            (0, 0, 2, 0, 0, 0, 0, 0),
            (0, 0, 3, 0, 0, 0, 0, 0),
        )
        self.incr = [LearnMonitor._make(row) for row in rows]

    def test_seqmonotonicity(self):
        """Check the is_min_* / is_max_* helpers on the ``incr`` fixture."""
        self.assertTrue( LearnMonitor.is_min_up("traincost", 2, self.incr) )
        self.assertFalse( LearnMonitor.is_min_still("traincost", 2, self.incr[:-1]) )
        self.assertTrue( LearnMonitor.is_min_up("traincost", 1, self.incr[:-1]) )
        self.assertTrue( LearnMonitor.is_max_still("traincost", 2, self.incr[:-1]) )

    def test_reldiff(self):
        """Check ``rel_diff`` in both directions of change."""
        # The expected values are floats, so compare with a tolerance: an
        # algebraically equivalent formula such as 1 - 90/100.0 evaluates to
        # 0.09999999999999998, not 0.1, in binary floating point.
        self.assertAlmostEqual( LearnMonitor.rel_diff(init=90, final=100), 0.1 )
        self.assertAlmostEqual( LearnMonitor.rel_diff(init=100, final=90), -0.1 )

    def test_learnmonitor_stats(self):
        """Check the epoch, learning rate and cost fields built by ``_from_fs``."""
        ds = data_tests.random_adataset()
        tds = data.TrainAnchorDataset(ds.pX, ds.sY, ds.dfT, 4)

        # Identity-like cost functions: the expected costs are the means of
        # the batch descriptors.  Compared approximately because _from_fs may
        # accumulate its mean differently from np.mean.
        mon = LearnMonitor._from_fs(tds,
                                    self._fake_model(0.1),
                                    0,
                                    cost_f=float,
                                    ce_f=float,
                                    mcl_f=float,
                                    pt=PT.classification)
        self.assertEqual(mon.epoch, 0)
        self.assertEqual(mon.lrate, 0.1)
        self.assertAlmostEqual(mon.traincost, np.mean(np.array(tds.train_batches)))
        self.assertAlmostEqual(mon.validcost, np.mean(np.array(tds.valid_batches)))

        # Constant cost functions: every batch costs 1, so the means are
        # exactly 1 and exact comparison is safe.
        mon = LearnMonitor._from_fs(tds,
                                    self._fake_model(0.1),
                                    0,
                                    cost_f=lambda i: 1,
                                    ce_f=lambda i: 1,
                                    mcl_f=lambda i: 1,
                                    pt=PT.classification)
        self.assertEqual(mon.epoch, 0)
        self.assertEqual(mon.lrate, 0.1)
        self.assertEqual(mon.traincost, 1.0)
        self.assertEqual(mon.validcost, 1.0)

    def test_weightmonitor_stats(self):
        """Check that WeightMonitor summarises layer-0 weights of a small CNN."""
        bs = 6
        model = nccn.CnnModel([(2,), ((4,4),), ((2,2),)], 12,
                        (bs, 1, 11, 11),
                        2, rng,
                        "float64", "int32")

        mon = WeightMonitor._from_model(0, 0, model)
        w, b = model[0].get_weights()
        self.assertEqual(mon.epoch, 0)
        self.assertEqual(mon.layer, 0)
        # NOTE(review): '32' is WeightMonitor's string encoding of the
        # layer-0 weight shape -- TODO confirm how it is derived.
        self.assertEqual(mon.wshp, '32')
        # The statistics are recomputed here with the same numpy reductions.
        self.assertEqual(mon.wmin, np.min(w))
        self.assertEqual(mon.wmax, np.max(w))
        self.assertEqual(mon.wmedian, np.median(w))
        self.assertEqual(mon.wsd, np.std(w))
