#!/usr/bin/env python

"""infer input template that conforms the given model"""

import logging
logging.basicConfig(level=logging.INFO)
import argparse
import os
from itertools import product, imap

import numpy as np
rng = np.random.RandomState()

from dimer.archive import dset_path, DSPEC_MSG, split
from dimer import data
from dimer.nnet.config_spec import CnnModelSpec
from dimer.nnet import nccn, ProblemType

from dimer.nnet.cnninv.gmodel import NormOpt
from dimer.nnet.cnninv.predictions import CnnMTempl
from dimer.nnet.cnninv import inv_sigmoid, inv_softmax

# This file is a command-line script: everything below runs at module level
# (argument parsing, archive reads, writing the output file), so importing it
# from other code is refused outright.
if __name__ != "__main__":
    raise ImportError("you cannot import this")


def simple_line(track_pos):
    """Return the tuple of values for one output row.

    `track_pos` is a `(t, i)` pair used to index the module-level arrays
    `X`, `mean_sig` and `median` as `[t, i]`; `i` also selects an entry of
    `ds.pX.minor_axis` and `t` an entry of `ds.pX.major_axis`.

    The result is ordered to match the row format used when writing the
    output: (start, track, score, mean signal, median signal).
    """
    t, i = track_pos
    start = ds.pX.minor_axis[i]
    track = ds.pX.major_axis.tolist()[t]
    return (start, track, X[t, i], mean_sig[t, i], median[t, i])


log = logging.getLogger(__name__)

# Command-line interface.  The module docstring doubles as the description
# shown by --help, and ArgumentDefaultsHelpFormatter appends every option's
# default value to its help text.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    epilog="Olgert Denas (Taylor Lab)")

# Positional arguments.
parser.add_argument("input", type=dset_path,
                    # e.g. "erythroid.h5:tfonly_R10000_B40_T3.5"
                    help="Input data." + DSPEC_MSG)
parser.add_argument("trainid", type=str,
                    # e.g. "tr__20131011_142946_88429_"
                    help="train id")

# Options that take a value.
parser.add_argument("--ptype", choices=ProblemType.choices,
                    default=ProblemType.choices[0],
                    help="Prediction type. classification or regression")
parser.add_argument("--ce", type=float, default=0.1,
                    help="CE of output. this will discriminate up/down genes.")
parser.add_argument("--stretch", type=float, default=-1,
                    help="Set the domain of input features around zero. -1 leaves it = CB")
parser.add_argument("--label", type=int, default=0, help="Label")
parser.add_argument("--layer", type=int, default=0, help="layer")

# Boolean switches.
parser.add_argument("--join", action="store_true", default=False,
                    help="Join identical adjacent features.")
parser.add_argument("--un_std", action="store_true", default=False,
                    help="Multiply by sd and add mean of training data.")
# NOTE: this is a store_false flag, so *passing* --header turns the header
# off; the help text states that explicitly because the flag name alone
# suggests the opposite.
parser.add_argument("--header", action="store_false", default=True,
                    help="Suppress the header line (it is printed unless "
                         "this flag is given).")
parser.add_argument("--output", type=str, default="/dev/stdout", help="Output")

opt = parser.parse_args()

# Fail fast: any layer other than 0 is rejected, so check it before the
# dataset and model are read from the archive rather than afterwards.
if opt.layer != 0:
    parser.error("only --layer 0 is supported")

#logging.getLogger("cnninv.gmodel").setLevel( logging.DEBUG )
#logging.getLogger("cnninv.predictions").setLevel( logging.DEBUG )

# Read the dataset and the model specification stored for this train id.
ds = data.AnnotatedDataset._from_archive(opt.input, False)
input_parts = split(opt.input)
ms = CnnModelSpec._from_archive(input_parts[0],
                                os.path.join(input_parts[1], opt.trainid))

# Build the model for input of shape (1, 1, ds.X.shape[1], ds.X.shape[2])
# and call its load() on the "<input>/<trainid>" path.
CM = nccn.CnnModel(ms.cp_arch, ms.lreg_size,
                   (1, 1, ds.X.shape[1], ds.X.shape[2]), ds.labels, rng,
                   xdtype=ds.X.dtype,
                   ydtype=ProblemType.ds_out(ds, opt.ptype).dtype)
CM.load("/".join((opt.input, opt.trainid)))

# Every tuple argument below has len(CM) entries (presumably one per model
# layer -- confirm against CnnMTempl).  The first entry of the first and of
# the last tuple depends on --stretch: with the default of -1 the pair is
# (1, -1) and the matching flag in the last tuple is True.
n_cm = len(CM)
PR = CnnMTempl(CM, ds, opt.label,
               ((-opt.stretch, opt.stretch),) + (((0.01, 0.99),) * (n_cm - 1)),
               (NormOpt,) * n_cm,
               ((inv_sigmoid,) * (n_cm - 1)) + (inv_softmax,),
               (NormOpt.OBJMIN,) * n_cm,
               (True,) * n_cm,
               (False,) * n_cm,
               (False,) * n_cm,
               (True,) * n_cm,
               (opt.stretch == -1,) + ((True,) * (n_cm - 1)),
               opt.ce)
# Template features for the requested layer, reshaped to the shape of a
# single example of ds.X.
X = PR.features(opt.layer).reshape(ds.X.shape[1:])

# Mean and median (along axis 0) of the rows of ds.X selected by
# ds.T == PR.label.
label_examples = ds.X[ds.T == PR.label]
mean_sig = np.mean(label_examples, axis=0)
median = np.median(label_examples, axis=0)

if opt.un_std:
    # Undo standardisation: multiply by the second element returned by
    # mean_sdX and add the first.
    # NOTE(review): only X is rescaled; mean_sig and median are left as
    # computed from ds.X -- confirm that is intended.
    mean_std = ds.mean_sdX(opt.input)
    scale = mean_std[1].values
    offset = mean_std[0].values
    X = (X * scale) + offset

# Lazily produce one output row per (t, i) index pair of X, with the
# second index varying fastest.
n_first, n_second = X.shape[0], X.shape[1]
track_pos_iter = product(xrange(n_first), xrange(n_second))
line_iterator = (simple_line(pos) for pos in track_pos_iter)

# Row template: the label name followed by the five per-row fields.  The
# label name becomes part of a %-format string, so any literal "%" in it is
# doubled; otherwise it would be parsed as a conversion specifier when the
# template is applied to each row.
ofrm_t = "\t".join((ds.label_names[opt.label].replace("%", "%%"),
                    "%d", "%s", "%.4f", "%.4f", "%.4f"))

# The file is opened in append mode, so rows are added after any existing
# content of opt.output.
with open(opt.output, 'a') as ofd:
    if opt.header:
        ofd.write("\t".join(("label", "start", "track",
                             "score", "mean_sig", "med_sig")) + "\n")
    for line in line_iterator:
        ofd.write((ofrm_t % line) + "\n")
