'''
Created on Jul 15, 2013

@author: odenas
'''

import unittest
from operator import attrgetter
import logging
logging.basicConfig()
log = logging.getLogger()

import numpy as np
rng = np.random.RandomState()
import theano

from dimer.nnet import nccn
from gmodel import NormOpt
from . import inv_sigmoid

class NormOptTest(unittest.TestCase):
    """Tests for ``NormOpt`` on a fully-connected and a conv-pool layer.

    Each test builds a layer with random weights and compiles its activation.
    It evaluates the activation on a constant input (``idb``, all 0.5) and on
    a random input (``icb``), then builds and solves a ``NormOpt`` model from
    those pairs. Finally it checks that the solution respects the variable
    bounds and the per-output constraints the model created.
    """

    # Absolute slack used in every floating-point comparison below.
    TOL = 1e-7

    @staticmethod
    def _parse_constr_name(name):
        """Turn a constraint name such as ``"0_3_7"`` into the index ``(0, 3, 7)``."""
        return tuple(int(part) for part in name.split("_"))

    def get_fflayer(self):
        """Return an ``nccn.HiddenLayer`` with random weights and zero bias.

        The weight matrix has a random shape (40..99 rows, 2..29 columns) with
        entries drawn uniformly from [-0.5, 0.5). The bias has one entry per
        weight column.
        """
        W = rng.rand(rng.randint(40, 100), rng.randint(2, 30)) - 0.5
        n_in, n_out = W.shape
        layer = nccn.HiddenLayer(theano.tensor.matrix("X"),
                                 n_in, n_out, rng, W.dtype)
        layer.set_weights((W, np.zeros((n_out,))))
        return layer

    def bounds(self, model):
        """Assert each model variable's solution value lies in [lb, ub], up to ``TOL``."""
        for idx in model.var_idx_iter():
            var = model.xvars[idx]
            value = var.getAttr("X")
            self.assertLessEqual(var.lb, value + self.TOL)
            self.assertGreaterEqual(var.ub, value - self.TOL)

    def sol(self, model, ocb, odb, parse_idx, actf, osh):
        """Check the solved model's solution against the constraints it created.

        :param model: solved ``NormOpt`` model
        :param ocb: layer output on the random input
        :param odb: layer output on the constant 0.5 input
        :param parse_idx: callable mapping a constraint name to an index into
            ``ocb``/``odb``. ``None`` selects ``_parse_constr_name``, which
            parses underscore-separated integers.
        :param actf: compiled activation function of the layer
        :param osh: shape the activation of the solution is reshaped to
        """
        if parse_idx is None:
            parse_idx = self._parse_constr_name
        tol = self.TOL

        activation = actf(model.get_sol()).reshape(osh)
        # Materialize the names once: they are iterated and counted, and a bare
        # ``map`` object has no ``len`` on Python 3.
        names = list(map(attrgetter("ConstrName"), model.model.getConstrs()))

        for idx in map(parse_idx, names):
            # Where the random-input output is above the constant-input output,
            # the solution's output must not drop below it; symmetrically when
            # it is below.
            if np.all(ocb[idx] > odb[idx]):
                self.assertGreaterEqual(activation[idx], ocb[idx] - tol)
            if np.all(ocb[idx] < odb[idx]):
                self.assertLessEqual(activation[idx], ocb[idx] + tol)

        # Expect exactly one constraint per output position where the outputs differ.
        n_differing = np.count_nonzero(ocb.flatten() != odb.flatten())
        self.assertEqual(len(names), n_differing)

    def test_ff(self):
        """Solve ``NormOpt`` on a random fully-connected layer and check the result."""
        layer = self.get_fflayer()
        W, _ = layer.get_weights()
        actf = theano.function([layer.input], layer.activation())
        n_in = W.shape[0]

        idb = 0.5 * np.ones((1, n_in))
        odb = actf(idb)
        icb = rng.rand(1, n_in)
        ocb = actf(icb)

        # NOTE(review): NormOpt's positional flags are passed through unchanged;
        # their meaning is defined in gmodel.
        model = NormOpt(layer, (idb, odb), (icb, ocb), rng.rand(*icb.shape),
                        (0, 1), inv_sigmoid,
                        NormOpt.OBJMIN, True, False, False, False, False)
        model.solve()

        self.bounds(model)
        self.sol(model, ocb, odb, self._parse_constr_name, actf, ocb.shape)

    def test_cp(self, batch_size=1):
        """Solve ``NormOpt`` on a random conv-pool layer and check the result.

        :param batch_size: leading dimension of the input and output shapes
        """
        ish = (batch_size, rng.randint(1, 10), rng.randint(2, 8), rng.randint(10, 110))
        fsh = (2, ish[1], 2, 3)
        # Output shape is computed as for a "valid" convolution of ish by fsh.
        osh = (batch_size, fsh[0], ish[2] - fsh[2] + 1, ish[3] - fsh[3] + 1)
        pool = (1, 1)

        layer = self.get_cplayer(ish, fsh, pool)
        actf = theano.function([layer.input], layer.activation())

        idb = 0.5 * np.ones(ish)
        odb = actf(idb).reshape(osh)
        icb = rng.rand(*ish)
        ocb = actf(icb).reshape(osh)

        model = NormOpt(layer, (idb, odb), (icb, ocb), rng.rand(*icb.shape),
                        (0, 1), inv_sigmoid,
                        NormOpt.OBJMIN, True, False, False, False, False)
        model.solve()

        self.bounds(model)
        self.sol(model, ocb, odb, self._parse_constr_name, actf, osh)

    def get_cplayer(self, ish, fsh, pool):
        """Return an ``nccn.ConvPoolLayer`` with random filters and zero bias.

        :param ish: 4-tuple input shape; ``ish[0]`` is the batch size (see ``test_cp``)
        :param fsh: 4-tuple filter shape; ``fsh[0]`` sets the bias length
        :param pool: pooling shape, passed through to ``ConvPoolLayer``
        """
        n_filters = fsh[0]
        W = rng.rand(*fsh) - 0.5
        b = np.zeros((n_filters,))

        layer = nccn.ConvPoolLayer(theano.tensor.dtensor4("X"),
                                   fsh, ish,
                                   rng, pool, W.dtype)
        layer.set_weights((W, b))
        return layer
