"""abstract tester methods and common test operations"""

import unittest
import tempfile
import numpy as np
# module-level RNG shared by the tester classes below (exposed as self.rng in
# their setUp); constructed without a seed, so draws differ from run to run
rng = np.random.RandomState()
import theano
import theano.tensor as T
from .. import archive

class NpyTester(object):
    """Mixin with numpy array assertions for ``unittest.TestCase`` subclasses.

    Each check collapses its operand (or the difference ``x - y`` of two
    operands) to its largest absolute element and compares that scalar with
    zero through the host class's ``assertEqual`` / ``assertAlmostEqual`` /
    ``assertGreater``.
    """

    @staticmethod
    def _peak(a):
        """Largest absolute element of ``a``."""
        return np.max(np.abs(a))

    def assertZeroArray(self, x):
        """Fail unless every element of ``x`` is exactly zero."""
        self.assertEqual(self._peak(x), 0)

    def assertAlmostZeroArray(self, x):
        """Fail unless every element of ``x`` rounds to zero at 7 decimal places."""
        self.assertAlmostEqual(self._peak(x), 0)

    def assertEqualArray(self, x, y):
        """Fail unless ``x - y`` is exactly zero everywhere."""
        delta = x - y
        self.assertZeroArray(delta)

    def assertAlmostEqualArray(self, x, y):
        """Fail unless ``x - y`` rounds to zero (7 places) everywhere."""
        delta = x - y
        self.assertAlmostZeroArray(delta)

    def assertDifferentArray(self, x, y):
        """Fail unless ``x`` and ``y`` differ in at least one element."""
        self.assertGreater(self._peak(x - y), 0)

class ModelTester(unittest.TestCase, NpyTester):
    """Abstract test case for multi-layer models.

    Subclasses implement ``get_model`` (a fresh model: a sequence of layers
    supporting ``len``, indexing and iteration, which also provides ``cost``,
    ``weight_norm`` and ``load``), ``get_input`` and ``get_output``.
    The ``_test_*`` methods are templates: their names do not start with
    ``test``, so unittest does not collect them on their own.
    """

    def setUp(self):
        self.rng = rng
        self.__rw = None  # lazily filled by rnd_weights()

    def get_model(self):
        """Return a freshly built model; must be overridden."""
        raise NotImplementedError("should be overriden")

    def get_output(self):
        """Return the target vector fed to ``cost`` with ``get_input()``; must be overridden."""
        raise NotImplementedError("should be overriden")

    def get_ouput(self):
        """Misspelled legacy alias of ``get_output``, kept for backward compatibility."""
        return self.get_output()

    def get_input(self):
        """Return an input batch for the model; must be overridden."""
        raise NotImplementedError("should be overriden")

    def rnd_weights(self):
        """Save (once) and return copies of the weights of a freshly built model.

        Returns a tuple with one list of copied weight arrays per layer; the
        snapshot is cached, so later calls return the same tuple.
        """
        if self.__rw is None:
            model = self.get_model()
            # comprehension rather than map(): yields real lists on Python 3 too
            self.__rw = tuple([np.copy(w) for w in layer.get_weights()]
                              for layer in model)
        return self.__rw

    def zero_model(self):
        """Build a fresh model and overwrite every layer's weights with zeros."""
        model = self.get_model()
        for layer in model:
            # zeros_like keeps shape and dtype, and unlike ``w - w`` it is
            # zero even for inf/nan entries
            layer.set_weights(tuple(np.zeros_like(w) for w in layer.get_weights()))
        return model

    def _test_s_cost_(self):
        """dependence of cost on weight L1/L2 norms"""
        M = self.get_model()
        # float division (``/ 10`` truncates to 0 on Python 2, silently
        # disabling the penalties); nonzero so both terms are exercised
        l1 = self.rng.randint(1, 10) / 10.0
        l2 = self.rng.randint(1, 10) / 10.0
        Y = T.lvector("Y")
        y = self.get_output()
        x = self.get_input()

        cf = theano.function(inputs=[M[0].input, Y], outputs=M.cost(Y, l1, l2))
        cf_00 = theano.function(inputs=[M[0].input, Y], outputs=M.cost(Y, 0, 0))
        penalty = l1 * M.weight_norm("l1") + l2 * M.weight_norm("l2")
        # floating-point results from different graphs: compare with a
        # relative tolerance instead of exact equality
        np.testing.assert_allclose(cf(x, y) - cf_00(x, y), penalty, rtol=1e-5)

        # with all-zero weights the penalty terms must vanish
        M = self.zero_model()
        cf = theano.function(inputs=[M[0].input, Y], outputs=M.cost(Y, l1, l2))
        cf_00 = theano.function(inputs=[M[0].input, Y], outputs=M.cost(Y, 0, 0))
        x = self.get_input()
        np.testing.assert_allclose(cf(x, y), cf_00(x, y), rtol=1e-5)

    @staticmethod
    def _same_weights(ma, mb):
        """True iff every weight array of every layer of ``ma`` equals its counterpart in ``mb``."""
        return all(np.array_equal(r, z)
                   for la, lb in zip(ma, mb)
                   for r, z in zip(la.get_weights(), lb.get_weights()))

    def _test_io(self):
        """model can save and load"""
        from nnet_tests import archive_model
        modela = self.get_model()
        modelb = self.get_model()

        # two independently built models should not start out identical
        self.assertFalse(self._same_weights(modela, modelb))

        modelb.load(archive_model(modela))
        self.assertTrue(self._same_weights(modela, modelb))


class CNNLayerTester(unittest.TestCase, NpyTester):
    """test class for CPLayer, HiddenLayer and LogisticRegression Layer

    Initializes a random number generator. Subclasses implement ``get_layer``
    (returns a freshly built layer) and ``get_input`` (returns an input
    array); from those this class defines:

      - zero_input,
      - zero_weights, rnd_weights,
      - zero_layer and rnd_layer

    The ``_test_*`` methods are templates: their names do not start with
    ``test``, so unittest does not collect them on their own.
    """

    def setUp(self):
        self.rng = rng
        self.__zw = None  # cache for zero_weights()
        self.__rl = None  # cache for rnd_layer()

    def get_layer(self):
        """Return a freshly built layer; must be overridden."""
        raise NotImplementedError("should be overriden")

    def get_input(self):
        """Return an input array for the layer; must be overridden."""
        raise NotImplementedError("should be overriden")

    def zero_input(self):
        """Return zeros with the shape and dtype of ``get_input()``."""
        inp = self.get_input()
        # ``shape`` is a tuple attribute, not a method
        return np.zeros(inp.shape, dtype=inp.dtype)

    def zero_weights(self):
        """Cached tuple of zero arrays, one per array of ``rnd_weights()``."""
        if self.__zw is None:
            # zeros_like keeps each array's shape and dtype; a comprehension
            # (rather than len()/indexing on map()) also works on Python 3
            self.__zw = tuple(np.zeros_like(w) for w in self.rnd_weights())
        return self.__zw

    def rnd_weights(self):
        """Weights of the cached layer returned by ``rnd_layer()``."""
        return self.rnd_layer().get_weights()

    def zero_layer(self):
        """Build a fresh layer and set all of its weights to zero."""
        l = self.get_layer()
        l.set_weights(self.zero_weights())
        return l

    def rnd_layer(self):
        """Build (on first call) and cache a layer from ``get_layer()``."""
        if self.__rl is None:
            self.__rl = self.get_layer()
        return self.__rl

    def _weights_inrange(self, thr, widx):
        """Assert ``|w| <= thr`` elementwise for weight array ``widx`` of a fresh layer."""
        l = self.get_layer()
        self.assertLessEqual(np.max(np.abs(l.get_weights()[widx])), thr)

    def _test_init_(self):
        """Layers report the weights they were given, zero speeds and zero biases."""
        def _weights_speeds_eqto(l, wlst):
            for z, w, s in zip(wlst, l.get_weights(), l.get_speeds()):
                self.assertTrue(np.all(w == z))
                self.assertTrue(np.all(s == z - z))

        _weights_speeds_eqto(self.zero_layer(), self.zero_weights())
        _weights_speeds_eqto(self.rnd_layer(), self.rnd_weights())

        # biases are all zero
        self.assertTrue(np.all(self.get_layer().get_weights()[1] == 0))

    def _test_norms_(self):
        """``weight_norm`` is the sum of |w| ("l1") and of w**2 ("l2")."""
        l = self.zero_layer()
        self.assertEqual(l.weight_norm("l1"), 0.0)
        self.assertEqual(l.weight_norm("l2"), 0.0)

        l = self.rnd_layer()
        # a real list: it is iterated twice below (a Python 3 map() would be
        # exhausted after the first sum)
        wlst = [np.abs(w) for w in self.rnd_weights()]
        # the layer may reduce in a different order/precision than numpy here:
        # compare with a relative tolerance rather than exact float equality
        np.testing.assert_allclose(l.weight_norm("l1"),
                                   sum(w.sum() for w in wlst), rtol=1e-5)
        np.testing.assert_allclose(l.weight_norm("l2"),
                                   sum((w ** 2).sum() for w in wlst), rtol=1e-5)

    def _test_activation(self):
        # placeholder: no activation checks implemented yet
        pass


