'''
Created on Jul 29, 2013

@author: odenas
'''

import unittest
import sys
import logging
logging.basicConfig()
log = logging.getLogger()

import numpy as np
rng = np.random.RandomState()

from dimer.nnet import nccn
from dimer.data import Dataset
from predictions import CnnMTempl
from gmodel import NormOpt

from . import inv_sigmoid, sigmoid, inv_softmax, softmax


class TmplTest(unittest.TestCase):
    """Shape checks for ``CnnMTempl`` wrapped around a small random ``CnnModel``.

    Hyper-parameters are drawn from the module-level ``rng`` (created without a
    seed), so the exact architecture can differ between runs.
    """

    def setUp(self):
        """Draw the random hyper-parameters consumed by :meth:`model`."""
        # NOTE(review): numpy's RandomState.randint excludes its upper bound, so
        # randint(1, 2) is always 1 (one conv layer; ish[2] == 1) and
        # randint(10, 11) is always 10. Confirm whether inclusive ranges were
        # intended before widening them -- kept as-is to preserve behavior.
        self.nkerns = [rng.randint(2, 4) for _ in range(rng.randint(1, 2))]
        self.lrsize = rng.randint(10, 11)
        self.ish = (1, 1, rng.randint(1, 2), rng.randint(100, 200))
        self.no = 2
        self.n = rng.randint(8, 10)

    def model(self):
        """Build a ``CnnMTempl`` around a fresh ``CnnModel`` and a random ``Dataset``.

        Every per-layer option list handed to ``CnnMTempl`` has exactly
        ``len(model)`` entries (one per model layer).
        """
        nconv = len(self.nkerns)
        model = nccn.CnnModel((self.nkerns,
                               [(1, 2)] * nconv,
                               [(1, 2)] * nconv), self.lrsize,
                              self.ish, self.no, rng, np.float64, "int32")
        log.debug("%s", model)
        # Alternating 0/1 targets (presumably class labels for self.no == 2).
        # np.arange keeps this a proper integer array on Python 3 as well, where
        # np.array(map(...)) would wrap the iterator in a 0-d object array.
        ds = Dataset(rng.rand(self.n, self.ish[2], self.ish[3]),
                     rng.rand(self.n),
                     np.arange(self.n) % 2)
        nlayers = len(model)
        return CnnMTempl(model, ds, rng.randint(0, self.no),
                         [(0, 1)] * nlayers,
                         [NormOpt] * nlayers,
                         ([inv_sigmoid] * (nlayers - 1)) + [inv_softmax],
                         [NormOpt.OBJMIN] * nlayers,
                         [True] * nlayers,
                         [False] * nlayers,
                         [False] * nlayers,
                         [True] * nlayers,
                         [False] + ([True] * (nlayers - 1)),
                         0.3)

    def test_cb_and_db(self):
        """cb(i) and db(i) share a shape with as many entries as layer i's output."""
        m = self.model()
        # Log every layer's shapes up front so they are visible even if an
        # assertion below fails on an early layer.
        for i in range(len(m.model)):
            log.debug("%d %s %s", i, m.model.ish[i], m.model.osh[i])
        for i in range(len(m.model)):
            cb = m.cb(i)
            db = m.db(i)
            self.assertEqual(cb.shape, db.shape)

            n_out = np.prod(m.model.osh[i])
            self.assertEqual(np.prod(db.shape), n_out)
            self.assertEqual(np.prod(cb.shape), n_out)

    def test_sol(self):
        """features(i) holds as many entries as layer i's input."""
        m = self.model()
        for i in range(len(m.model)):
            n_in = np.prod(m.model.ish[i])
            self.assertEqual(np.prod(m.features(i).shape), n_in)
