#!/usr/bin/env python

"""Train a stack of autoencoders"""

import sys
import os
import argparse
import logging
from operator import attrgetter

import numpy as np
from theano.tensor.shared_randomstreams import RandomStreams

from dimer import data
from dimer.argutils import ArgTypes
from dimer.archive import dset_path, DSPEC_MSG, split
from dimer.nnet import ProblemType, train, config_spec, autoencoder, monitor

# `lg` is the root logger and is used for all messages of this script;
# basicConfig gives it a stderr handler that lets INFO and above through.
lg = logging.getLogger()
logging.basicConfig(level=logging.INFO)


def set_layer_weights(layer_path):
    """Initialize the parameters of one layer from a weights file.

    :param layer_path: a ``(layer, path_)`` pair. ``layer`` must provide
        ``get_params()`` (objects with a ``name`` attribute) and
        ``set_weights(list)``. ``path_`` is passed to ``np.load``, and the
        loaded object is indexed by parameter name.
    :returns: None. If reading ``path_`` raises ``IOError``, a warning is
        logged and the layer is left untouched.

    NOTE(review): indexing by name implies an ``.npz`` archive (or a
    structured array), although the CLI help mentions ``.npy`` files.
    Confirm against the files actually produced.
    """
    # A single pair argument keeps the call shape f((layer, path)) used with
    # zip(). Unpack explicitly: tuple parameters are Python-2-only (PEP 3113).
    layer, path_ = layer_path
    try:
        np_weights = np.load(path_)
        # One array per parameter, in get_params() order. Build a real list,
        # because map() would produce a lazy iterator on Python 3.
        weights = [np_weights[par.name] for par in layer.get_params()]
    except IOError:
        lg.warning("weights from %s cannot be loaded. skipping ...", path_)
    else:
        # Outside the try: a failure while assigning weights is not a
        # "cannot be loaded" condition and must not be swallowed.
        layer.set_weights(weights)


# Everything below runs at module level, so refuse to be imported.
_IMPORT_ERROR_MSG = "this is a script. cannot import"
if __name__ != "__main__":
    sys.stderr.write(_IMPORT_ERROR_MSG + "\n")
    raise SystemExit(1)

## command line interface
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    epilog="Olgert Denas (Taylor Lab)")
parser.add_argument("input", type=dset_path, help="Input data." + DSPEC_MSG)
parser.add_argument("settings", type=str, help="Experiment file of settings.")
parser.add_argument("hiddens", nargs="+", type=ArgTypes.is_pint,
                    help="Sizes of hidden layers")

parser.add_argument("--noise", type=float, default=0.1,
                    help="Add noise to training samples. Must be in [0, 1]")
parser.add_argument("--tstrat", choices=("min_epochs", "max_epochs", "estop"),
                    default="estop",
                    help="Stop training at min/max nr. of epochs, or early stopping.")
parser.add_argument("--repfreq", type=int, default=1,
                    help="Frequency with which to print reports")
# NOTE: the default seed is drawn anew on every invocation, so the value shown
# by --help changes from run to run.
parser.add_argument("--seed", type=ArgTypes.is_uint,
                    default=np.random.randint(1000000),
                    help="Seed of random number generator")
parser.add_argument("--valid_size", type=float, default=0.2,
                    help="Fraction of the data for validation")
# Adjacent string literals are joined verbatim, so every fragment that ends a
# sentence carries its own trailing space.
parser.add_argument("--valid_idx", type=int, default=4,
                    help=("Index from where to slice validation data. "
                          "In [0, |batches| - |valid_batches|]"))
parser.add_argument("--raw", action='store_true', default=False,
                    help=("Use the raw version of the data. "
                          "Raw data are fitted in [0,1], but not normalized."))
parser.add_argument("--resume", type=str, default="",
                    help=("Resume training a model. Simply initialize the "
                          "weights with values in the given training "
                          "snapshot (the seed is also taken from the "
                          "snapshot name)"))
parser.add_argument("--train_layers", nargs="+", type=ArgTypes.is_uint,
                    default=None,
                    help=("Indicate the indexes (0-based) of layers to train. "
                          "All layers are trained if omitted"))
parser.add_argument("--load_layers", nargs="+", type=str,
                    default=None,
                    help=("Indicate .npy files to be loaded and "
                          "initialize a layer to some specified array; the "
                          "i-th file initializes the i-th layer. Paths that "
                          "cannot be loaded are skipped with a warning"))
parser.add_argument("--prefix", type=str, default="",
                    help="Prefix of the experiment name")
opt = parser.parse_args()

## option sanity checks, and defaults that depend on other options
if not 0 <= opt.noise <= 1:
    parser.error("noise must be in [0, 1]")
# warn if any hidden layer is larger than the one before it
if any(prev < nxt for prev, nxt in zip(opt.hiddens, opt.hiddens[1:])):
    lg.warning("hidden sizes do not decrease")
if opt.resume:
    # a resumed run reuses the seed encoded in the snapshot name
    opt.seed = train.UTrainer.seed_from_train_name(opt.resume)
if opt.train_layers is None:
    # default: train every layer, one per requested hidden size
    opt.train_layers = list(range(len(opt.hiddens)))

## initialize random number generators: the numpy RNG is handed to the
## dataset loader and the model; the theano RandomStreams, seeded from it,
## is handed to the model
rng = np.random.RandomState(opt.seed)
thrng = RandomStreams(rng.randint(100000))


## define model
tr = config_spec.MtrainSpec._from_settings(opt.settings)
ds = data.TrainAnnotatedDataset._from_archive(opt.input, opt.raw, tr.batch_size,
                                              valid_s=opt.valid_size,
                                              valid_idx=opt.valid_idx,
                                              rng=rng)
ds.flatten()
if opt.resume:
    # NOTE: in this branch the model is rebuilt by AEStack._from_archive;
    # opt.hiddens and opt.noise are not passed to it.
    input_parts = split(opt.input)  # split once, use both components
    model = autoencoder.AEStack._from_archive(
        input_parts[0],
        os.path.join(input_parts[1], opt.resume),
        rng, thrng, ds.X.dtype, zero_clevel=False)
else:
    model = autoencoder.AEStack(ds.X.shape[1], opt.hiddens,
                                rng, thrng, ds.X.dtype, opt.noise)

if opt.load_layers:
    # Pair the i-th path with the i-th layer (zip stops at the shorter
    # sequence). Use an explicit loop: calling map() only for its side effects
    # is unidiomatic, and on Python 3 map() is lazy, so no weights would ever
    # be loaded.
    for layer_and_path in zip(model, opt.load_layers):
        set_layer_weights(layer_and_path)

## Train the requested layers one at a time. trainer.save runs after every
## layer, whether its training finished, was interrupted with Ctrl+C, or
## raised (in which case the error still propagates after saving).
trainer = train.UTrainer(ds, model, tr, monitor.DaeLearnMonitor, opt.seed,
                         prefix=opt.prefix)
for layer_idx in opt.train_layers:
    try:
        trainer.train_layer(layer_idx, opt.tstrat, opt.repfreq)
    except KeyboardInterrupt:
        # Ctrl+C only aborts the current layer; the loop moves on to the next
        lg.warning("Ctrl+C captured. will save weights and continue on next layer")
    finally:
        trainer.save(opt.input)
