#!/usr/bin/env python

"""Train a convolutional neural net for classification"""

import sys
import os
import argparse
import logging

import numpy as np

from dimer import data
from dimer.argutils import ArgTypes, attach_config
from dimer.archive import dset_path, DSPEC_MSG, split
from dimer.nnet import ProblemType, train, config_spec, nccn

## Refuse to be imported: everything below runs at module level (argument
## parsing, data loading, training). The check runs before any other side
## effect so an importer does not get its root logger reconfigured.
## sys.stderr.write is used instead of the Python-2-only `print >>` form,
## which raises TypeError under Python 3.
if __name__ != "__main__":
    sys.stderr.write("this is a script. cannot import\n")
    sys.exit(1)

logging.basicConfig(level=logging.INFO)

## command line interface; defaults are shown in --help via
## ArgumentDefaultsHelpFormatter
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    epilog="Olgert Denas (Taylor Lab)")
parser.add_argument("settings", type=ArgTypes.is_file,
                    help="Experiment file of settings.")
parser.add_argument("input", type=ArgTypes.is_dspath,
                    help="Input data." + DSPEC_MSG)
parser.add_argument("--tstrat", choices=("min_epochs", "max_epochs", "estop"),
                    default="estop",
                    help="Stop training at min/max nr. of epochs, or early stopping.")
parser.add_argument("--ptype", choices=ProblemType.choices,
                    default=ProblemType.choices[0],
                    help="Prediction type. classification or regression")
parser.add_argument("--repfreq", type=int, default=1,
                    help="Frequency with which to print reports")
## the default seed is drawn once, when the parser is built, so the value
## shown by --help is the one that would be used
parser.add_argument("--seed", type=ArgTypes.is_uint,
                    default=np.random.randint(1000000),
                    help="Seed of random number generator")
parser.add_argument("--valid_size", type=float, default=0.2,
                    help="Fraction of the data for validation")
## adjacent string literals are concatenated verbatim, so each sentence
## needs its own trailing space
parser.add_argument("--valid_idx", type=ArgTypes.is_uint, default=4,
                    help=("Index from where to slice validation data from. "
                          "In [0, |batches| - |valid_batches|]"))
parser.add_argument("--raw", action='store_true', default=False,
                    help=("Use the raw version of the data. "
                          "Raw data are fitted in [0,1], but not normalized."))
parser.add_argument("--resume", type=str, default="",
                    help=("Resume training a model. Simply initialize the "
                          "weights with values in the given training "
                          "snapshot"))
parser.add_argument("--prefix", type=str, default="",
                    help="Prefix of the experiment name")
## add one option per model-spec field (used below to override the
## settings file)
attach_config(config_spec.CnnModelSpec, parser)
opt = parser.parse_args()

## when resuming, the seed is recovered from the snapshot name and replaces
## any --seed value; log it so the override is not silent
if opt.resume:
    opt.seed = train.CnnTrainer.seed_from_train_name(opt.resume)
    logging.info("resuming from %s; using seed %s", opt.resume, opt.seed)

## one seeded generator shared by data splitting and weight initialization
rng = np.random.RandomState(opt.seed)

## read the training and model specifications from the settings file
tr = config_spec.MtrainSpec._from_settings(opt.settings)
ms = config_spec.CnnModelSpec._from_settings(opt.settings)

## command-line values replace the settings-file ones, but only when the
## command-line value is truthy (None, 0, "" and empty sequences are ignored)
ms = ms._replace(**{fld: getattr(opt, fld)
                    for fld in ms._fields
                    if getattr(opt, fld)})
ms._check_consistency()

## load the annotated dataset and carve out the validation slice
ds = data.TrainAnnotatedDataset._from_archive(
    opt.input, opt.raw, tr.batch_size,
    valid_s=opt.valid_size, valid_idx=opt.valid_idx, rng=rng)

## network architecture: convolutional layer spec plus the input shape,
## which is (batch_size, 1) followed by the per-example dimensions of ds.X
conv_spec = (ms.nkerns, ms.rfield, ms.pool)
batch_shape = (tr.batch_size, 1) + ds.X.shape[1:]
model = nccn.CnnModel(conv_spec, ms.lreg_size, batch_shape,
                      ProblemType.ds_nout(ds, opt.ptype), rng,
                      xdtype=ds.X.dtype,
                      ydtype=ProblemType.ds_out(ds, opt.ptype).dtype)

## --resume names a training snapshot inside the input archive whose
## weights initialize the model
if opt.resume:
    model.load(os.path.join(opt.input, opt.resume))

## train; the loss-monitor class depends on the problem type
trainer = train.CnnTrainer(ds, model, tr, ProblemType.cnn_lmon(opt.ptype),
                           opt.seed, opt.prefix)
try:
    trainer.train(opt.tstrat, opt.repfreq)
except KeyboardInterrupt:
    ## a Ctrl+C only ends training early; the weights are still saved below
    logging.warning("Ctrl+C captured. will save weights and exit")

## save: reached after normal completion or a Ctrl+C; any other exception
## propagates and skips the save
trainer.save(opt.input)
