
import unittest
import numpy as np
from dimer.ops import binned_signal, standardize, fit, get_valid_idx
from dimer.ops import inverse_hsine

class Test_utils(unittest.TestCase):
    """Unit tests for the helpers imported from ``dimer.ops``."""

    def setUp(self):
        # Number of randomized trials used by the property-style tests.
        self.n = 500

    def assertAEq(self, a, b):
        """Assert that arrays *a* and *b* have equal shapes and elements."""
        self.assertEqual(a.shape, b.shape)
        self.assertTrue(np.all(a == b))

    def test_binned_signal(self):
        """binned_signal reduces consecutive, equal-sized bins with ``bin_f``.

        For random signal lengths and bin sizes: a length that is not a
        multiple of the bin size must raise ValueError; otherwise the output
        has one element per bin, equal to ``bin_f`` of that input slice.
        """
        for _ in range(self.n):
            L = np.random.randint(100) + 5
            a = np.random.rand(L)
            bin_size = np.random.randint(L // 3)
            # randint may return 0, which is not a usable bin size.
            if not bin_size:
                continue
            bin_f = getattr(np, ("mean", "sum")[np.random.randint(2)])
            ctx = "L=%d bin_size=%d bin_f=%s" % (L, bin_size, bin_f.__name__)

            if L % bin_size:
                self.assertRaises(ValueError, binned_signal, a, bin_size, bin_f)
                continue

            n_bins = L // bin_size
            ba = binned_signal(a, bin_size, bin_f)
            self.assertEqual(ba.shape, (n_bins,), msg=ctx)
            for j in range(n_bins):
                expected = bin_f(a[j * bin_size:(j + 1) * bin_size])
                # Compare with a tolerance: a vectorized binning and a
                # per-slice reduction may sum in a different order.
                self.assertAlmostEqual(float(ba[j]), float(expected), msg=ctx)

    def test_standardize(self):
        """standardize yields zero-mean, unit-std data plus the mean and std.

        Checks one random 1-D vector, then random 2-D matrices standardized
        along a randomly chosen axis.
        """
        L = np.random.randint(100) + 5

        x = np.random.rand(L)
        sx, mx, vx = standardize(x)

        self.assertEqual(sx.shape, x.shape)
        self.assertAlmostEqual(mx, x.mean())
        self.assertAlmostEqual(vx, x.std())
        self.assertAlmostEqual(sx.mean(), 0)
        self.assertAlmostEqual(sx.std(), 1)

        for _ in range(100):
            L = np.random.randint(100) + 5
            axis = np.random.randint(0, 2)
            x = np.random.rand(L, L // 2)
            sx, mx, vx = standardize(x, axis=axis)
            ctx = "shape=%s axis=%d" % (x.shape, axis)

            self.assertEqual(sx.shape, x.shape, msg=ctx)
            # Use the largest *absolute* deviation so that negative errors
            # are caught as well as positive ones.
            self.assertAlmostEqual(
                np.max(np.abs(sx.mean(axis=axis))), 0, msg=ctx)
            self.assertAlmostEqual(
                np.max(np.abs(sx.std(axis=axis) - 1)), 0, msg=ctx)

    @unittest.skip("todo")
    def test_fit(self):
        pass

    @unittest.skip("todo")
    def test_binning_effect(self):
        pass

    def test_valid_idx(self):
        """get_valid_idx marks a row invalid if any input matrix has a NaN in it."""
        # A real list (not a lazy map object) so items can be indexed/mutated.
        x = [np.random.rand(10, 20) for _ in range(2)]
        x[0][3, 5] = np.nan
        x[0][3, 10] = np.nan
        x[0][4, 7] = np.nan
        x[1][1, 6] = np.nan
        x[1][3, 8] = np.nan
        invalid_rows = (1, 3, 4)

        a = get_valid_idx(x)
        self.assertEqual(a.shape[0], x[0].shape[0])
        for i in range(a.shape[0]):
            if i in invalid_rows:
                self.assertFalse(a[i], msg="row %d should be invalid" % i)
            else:
                self.assertTrue(a[i], msg="row %d should be valid" % i)

    def test_smooth_hsine(self):
        """inverse_hsine preserves the input shape and maps zeros to zero."""
        x = np.random.rand(10, 20)
        # rand() practically never yields an exact 0, so plant some zeros;
        # otherwise the zero-preservation check below never runs.
        x[0, 0] = 0
        x[4, 13] = 0
        # Record zero positions before the call in case x is modified.
        zero_mask = x == 0
        sx = inverse_hsine(x)

        self.assertEqual(x.shape, sx.shape)
        for i, j in zip(*np.nonzero(zero_mask)):
            self.assertEqual(sx[i, j], 0)
