'''
Created on Jul 15, 2013

@author: odenas
'''
import unittest
import numpy as np
# Shared random generator for the tests in this module. It is constructed
# without a seed, so the randomly drawn test inputs differ from run to run.
rng = np.random.RandomState()

from . import constraint_boundary

class TestCnninv(unittest.TestCase):
    """Unit tests for helpers exposed by the enclosing package.

    Each test imports the helper it exercises with ``from . import ...`` and
    draws its random inputs from the module-level ``rng``.
    """

    def test_showtxt(self):
        """``draw_array`` output must equal one right-aligned dot row per threshold."""
        from . import draw_array
        thr = [1, 3, 5, 7, 9]
        rows = draw_array(np.array(range(10)), thr=thr)
        # Expected row for threshold t: (10 - t) spaces followed by t dots.
        erows = [(" " * (10 - t)) + ("." * t) for t in thr]
        print(rows)
        print(erows)
        self.assertEqual(rows, erows)

    def test_tup_prod(self):
        """``tup_prod((m, n))`` must yield all (i, j) pairs in row-major order.

        With a true second argument the same pairs are expected in reverse.
        """
        from . import tup_prod
        m = rng.randint(10)
        n = rng.randint(10)
        el = [(i, j) for i in range(m) for j in range(n)]

        self.assertEqual(list(tup_prod((m, n))), el)

        self.assertEqual(list(tup_prod((m, n), True)),
                         list(reversed(el)))

    # ``unittest.SkipTest`` is an exception class, not a decorator; using it as
    # one silently removed the test from collection. ``unittest.skip`` reports
    # the test as skipped instead.
    @unittest.skip("LayerCheck is not defined or imported in this test module")
    def test_margin(self):
        """Check ``LayerCheck`` ok/fail counts and total margin (currently skipped)."""
        N = rng.randint(10)
        y = np.ones((N,))

        l = LayerCheck(np.ones((N,)), y)
        self.assertEqual(l.ok_num, N)
        self.assertEqual(l.fail_num, N - l.ok_num)
        self.assertEqual(l.total_margin, 0)

        # Floor division keeps the slice bound and the expected counts integral
        # on both Python 2 and Python 3.
        half = N // 2
        y[:half] = -1
        l = LayerCheck(np.ones((N,)), y)
        print(l.margin)
        print(l.margin.shape)
        print(l.margin[l.margin >= 0])
        self.assertEqual(l.fail_num, half)
        self.assertEqual(l.ok_num, N - l.fail_num)
        self.assertEqual(l.total_margin, -2 * half)

    def test_unpool(self):
        """``unpool`` must scale each spatial axis by the pool size.

        Every ``p[0] x p[1]`` output patch is expected to be constant and equal
        to the corresponding input value divided by the patch area. Checked
        for a 2-D input and for a 3-D input whose first axis is left unscaled.
        """
        from . import unpool

        # 2-D input
        X = rng.rand(10, 100)
        p = (2, 2)
        uX = unpool(X, p)
        # Rows scale with p[0] and columns with p[1], matching the patch
        # slicing below and the 3-D case.
        self.assertEqual(uX.shape, (X.shape[0] * p[0], X.shape[1] * p[1]))

        z = p[0] * p[1]
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                patchX = uX[i * p[0]:(i + 1) * p[0], j * p[1]:(j + 1) * p[1]]
                # min == max == expected value means the patch is constant.
                self.assertEqual(patchX.min(), X[i, j] / z)
                self.assertEqual(patchX.max(), X[i, j] / z)

        # 3-D input
        X = rng.rand(4, 10, 100)
        p = (2, 2)
        uX = unpool(X, p)
        self.assertEqual(uX.shape,
                         (X.shape[0], X.shape[1] * p[0], X.shape[2] * p[1]))
        z = p[0] * p[1]
        for k in range(X.shape[0]):
            for i in range(X.shape[1]):
                for j in range(X.shape[2]):
                    patchX = uX[k,
                                i * p[0]:(i + 1) * p[0],
                                j * p[1]:(j + 1) * p[1]]
                    self.assertEqual(patchX.min(), X[k, i, j] / z)
                    self.assertEqual(patchX.max(), X[k, i, j] / z)

    def test_sigmoids(self):
        """``sigmoid(inv_sigmoid(X))`` must reproduce ``X`` (up to rounding)."""
        from . import sigmoid, inv_sigmoid
        X = rng.rand(10, 100)

        # The previous form passed X as the assertion *message* and therefore
        # never compared the round-trip result against X.
        self.assertTrue(np.allclose(sigmoid(inv_sigmoid(X)), X))

    @unittest.skip("disabled by the original author; reason not recorded")
    def test_softmaxs(self):
        """``softmax(inv_softmax(X))`` round trip (currently skipped)."""
        from . import softmax, inv_softmax
        X = rng.rand(100)

        # NOTE(review): if softmax normalises its output to sum to 1, an
        # arbitrary X will not round-trip -- confirm before enabling this test.
        self.assertTrue(np.allclose(softmax(inv_softmax(X)), X))

    def test_cb(self):
        """Check ``constraint_boundary(s, t, x, 0, 1)`` for x at the bounds and random x."""
        # t lies within 0.5 +/- 0.1
        t = (np.ones((10,)) * 0.5) + ((rng.rand(10) - .5) / 5)

        s = rng.rand(10)

        # NOTE(review): np.all(np.abs(CB - s)) collapses to a single boolean,
        # so the two assertions below only fail when *every* element of CB
        # differs from s; they do not check CB == s element-wise. They are
        # kept unchanged because the intended expectation cannot be confirmed
        # from this file. The per-element rules asserted further down suggest
        # the expected values may be np.maximum(s, t) and np.minimum(s, t)
        # respectively -- TODO confirm and tighten.

        ## set X = ub
        CB = constraint_boundary(s, t, np.ones((10,)), 0, 1)
        self.assertAlmostEqual(np.all(np.abs(CB - s)), 0)

        ## set X = lb
        CB = constraint_boundary(s, t, np.zeros((10,)), 0, 1)
        self.assertAlmostEqual(np.all(np.abs(CB - s)), 0)

        ## X = random
        x = (np.ones((10,)) * 0.5) + ((rng.rand(10) - .5))
        CB = constraint_boundary(s, t, x, 0, 1)
        for i in range(10):
            print(t[i])
            print("%s %s" % (s[i], x[i]))
            print(CB[i])
            print("")
            # Floating-point results are compared with a tolerance rather than
            # bit-for-bit, so a different but equivalent evaluation order in
            # the implementation does not cause spurious failures.
            if t[i] >= x[i] and t[i] >= s[i]:
                ## same sign: x and s both at or below t
                a = abs(t[i] - x[i]) / (t[i] - 0)
                b = abs(t[i] - s[i])
                self.assertAlmostEqual(CB[i], t[i] - a * b)
            elif t[i] < x[i] and t[i] < s[i]:
                ## same sign: x and s both above t
                a = abs(t[i] - x[i]) / (1 - t[i])
                b = abs(t[i] - s[i])
                self.assertAlmostEqual(CB[i], t[i] + a * b)
            else:
                ## x and s on opposite sides of t
                self.assertAlmostEqual(CB[i], t[i])
