"""module that automates model training"""

import logging
import re
import os.path
import datetime
from operator import concat, attrgetter
from collections import deque
import numpy as np
import theano
import theano.tensor as T

from nccn import CnnModel
from config_spec import CnnModelSpec, MtrainSpec
from monitor import LearnMonitor
from . import ProblemType, adjust_lr
from .. import data, archive

log = logging.getLogger(__name__)


class BaseTrainer(object):
    """a basic trainer that keeps references to dataset, model, seed"""

    ## names produced by train_name look like
    ## tr_<prefix>_<YYYYmmdd_HHMMSS>_<seed>_<suffix>
    ## (raw string: `\d` is not a valid escape in a normal string literal)
    _TRAIN_NAME_RE = re.compile(r"(tr)_(.*)_(\d{8}_\d{6})_(\d+)_.*")

    def __init__(self, ds, mtr, model, seed, prefix):
        """
        :param ds: the training dataset
        :param mtr: the training spec (read for minepochs, patience, nepochs)
        :param model: the model to train
        :param int/None seed: seed recorded in the train name; generated if None
        :param str prefix: prefix of the train name"""
        self.tr = mtr
        self.ds = ds
        self.model = model
        self.mcls = None

        ## initialize random stream here otherwise it will not
        ## appear in the train id
        if seed is None:
            seed = np.random.randint(0, 100000)
            log.warning("""seed not initialized. The one I will generate for you,
                        is very likely to be different from the one you used to
                        initialize dhe Training dataset. Thus, experiment is
                        unlikely to be reproducible""")
        log.info("SEED for this train is : %s", str(seed))
        self.__name_of_this_run = self.train_name(prefix, seed)

        ## early stopping never stops before minepochs + patience epochs
        self.min_epochs = self.tr.minepochs + self.tr.patience
        self.max_epochs = self.tr.nepochs

    def __str__(self):
        return self.__name_of_this_run

    @property
    def epoch(self):
        """number of learn-info records so far (subclasses define `learninfo`)"""
        return len(self.learninfo)

    @staticmethod
    def seed_from_train_name(trid):
        """extract the seed used for this training

        :param str trid: a train name possibly generated through
            BaseTrainer.train_name
        :rtype: int. the seed
        :raises ValueError: if `trid` does not look like a train name
        """

        match = BaseTrainer._TRAIN_NAME_RE.match(trid)
        if match is None:
            raise ValueError("not a valid train name: %r" % (trid,))
        return int(match.group(4))

    @staticmethod
    def train_name(prefix, seed, suffix=""):
        """an ID assigned to an experiment. this will correspond to a
        train subdir inside archive named
        **tr_<prefix>_<timestamp>_<seed>_<suffix>**"""

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        seed = str(seed)
        return "tr_%s_%s_%s_%s" % (prefix, timestamp, seed, suffix)

    def save(self, dspath):
        """save the trained information: the model, its spec and
        (through the monitor class) the learn info

        :param str dspath: path saving location of the type archive:dataset
        :rtype: None"""

        arch, dsn = archive.split(dspath)
        self.model.save(os.path.join(dspath, str(self)))
        self.model.mspec.archive_dump(arch, os.path.join(dsn, str(self)))
        self.mcls._archive_dump(dspath, str(self), self.learninfo)


class GlobalSGDTrainer(object):
    """mixin that provides basic methods for SGD training, all layers
    at once.

    The importing class must implement ``gdstep(report)`` and provide the
    attributes read here: ``tr``, ``ds``, ``model``, ``mcls``, ``learninfo``,
    ``last_params``, ``probtype``, ``_compiled_f``, ``min_epochs``,
    ``max_epochs`` and ``epoch``."""

    def __fixed_epoch_train(self, nepochs, rfreq):
        """run `nepochs` gradient steps, reporting every `rfreq`-th one"""
        for i in range(nepochs):
            self.gdstep(i % rfreq == 0)

    def __early_stopping_train(self, rfreq):
        """train for min_epochs, then keep stepping until `validcost` has
        gone up (see is_up) or max_epochs is reached"""
        self.__fixed_epoch_train(min(self.min_epochs, self.max_epochs), rfreq)

        patience = self.tr.patience
        for epoch in range(self.min_epochs, self.max_epochs):
            if self.is_up():
                log.info("validcost is up. Done training")
                ## restoring requires a full history of patience + 1 snapshots
                assert len(self.last_params) == patience + 1, \
                    ("len(last_params) is %d expected %d" % (len(self.last_params),
                                                             patience + 1))
                self.restore_last_params()
                break
            self.gdstep(epoch % rfreq == 0)
        else:
            ## the loop was never broken: the whole epoch budget was used
            log.info("Hit maxepoch (%d). Done training!", self.max_epochs)

    def train(self, strategy, rfreq):
        """train the model

        :param str strategy: "min_epochs" or "max_epochs" to run that many
            epochs, "estop" for early stopping
        :param int rfreq: report every `rfreq` epochs
        :raises ValueError: on an unknown strategy"""
        if strategy in ("min_epochs", "max_epochs"):
            self.__fixed_epoch_train(getattr(self, strategy), rfreq)
        elif strategy == "estop":
            self.__early_stopping_train(rfreq)
        else:
            raise ValueError("un-supported learning strategy: %s" % strategy)

    def is_up(self, what="validcost"):
        """has `what` gone up over the last patience epochs?
        (delegates to the monitor class' is_min_up)"""
        return self.mcls.is_min_up(what, self.tr.patience, self.learninfo)

    def func(self, what):
        """return the compiled (and cached) theano function `what`

        :param str what: one of "ce", "acc", "cost", "grad"; each compiled
            function takes a batch index
        :raises ValueError: on an unknown function name"""

        if what not in ("ce", "acc", "cost", "grad"):
            raise ValueError("cannot compile function %s" % what)

        if what in self._compiled_f:
            return self._compiled_f[what]

        index = T.iscalar("batch_index")
        in_bs = self.ds.batch_size
        layers = self.model
        shX, shY = (self.ds.share("X", self.ds.X.shape),
                    ProblemType.ds_shout(self.ds, self.probtype))
        ## feed the model one minibatch selected by `index`
        givens = {layers.X: shX[index * in_bs: (index + 1) * in_bs],
                  layers.Y: shY[index * in_bs: (index + 1) * in_bs]}
        tr = self.tr

        if what == "ce":
            ## cost without regularization
            f = theano.function(inputs=[index], outputs=layers.cost(0, 0),
                                givens=givens)
        elif what == "cost":
            f = theano.function(inputs=[index],
                                outputs=layers.cost(tr.l1_rate, tr.l2_rate),
                                givens=givens)
        elif what == "grad":
            ## all parameters of all layers, in layer order
            params = [p for layer in layers for p in layer.get_params()]
            f = theano.function([index],
                                outputs=T.grad(layers.cost(tr.l1_rate, tr.l2_rate),
                                               wrt=params),
                                givens=givens)
        elif what == "acc":
            if self.probtype == ProblemType.regression:
                f = self.func("cost")
            else:
                ## number of mispredicted examples in the batch
                f = theano.function(inputs=[index],
                                    outputs=T.sum(T.neq(layers[-1].y_hat, layers.Y)),
                                    givens=givens)
        return self._compiled_f.setdefault(what, f)

    def restore_last_params(self, learnlog_lst=None):
        """last patience epochs were bad, so restore
        the parameters of the model from the first element of last_params

        :param learnlog_lst: unused"""

        ## restore the best parameters
        ## (from patience-1 i.e., the first element of last_params)
        self.model.weights = self.last_params[0]
        log.info("Restored model from epoch %d",
                 (self.epoch - self.tr.patience - 1))
        ## drop the learn info of the discarded epochs as well
        self.learninfo = self.learninfo[:-self.tr.patience]


class MlpTrainer(BaseTrainer, GlobalSGDTrainer):
    """trainer that updates all layers of the model at once with SGD"""

    def __init__(self, ds, model, mtr, lmon, seed, prefix=""):
        """
        :param ds: the training dataset
        :param model: the model; the type of model[-1] selects the problem type
        :param mtr: the training spec
        :param lmon: monitor class used to build learn-info records
        :param seed: seed or None (see BaseTrainer)
        :param str prefix: prefix of the train name"""
        ## also sets min_epochs and max_epochs
        BaseTrainer.__init__(self, ds, mtr, model, seed, prefix)

        ## snapshots of the weights after each of the last patience + 1 steps
        self.last_params = deque([], maxlen=self.tr.patience + 1)
        self.mcls = lmon
        self.learninfo = []
        self._compiled_f = {}
        ## the problem type is the ProblemType.top_layer key whose value is
        ## the type of the model's top layer (a comprehension rather than
        ## filter(...)[0], which only works where filter returns a list)
        matches = [k for k in ProblemType.top_layer
                   if type(model[-1]) == ProblemType.top_layer[k]]
        self.probtype = ProblemType.parse(matches[0])

    def _monitor_record(self, epoch):
        """build a learn-info record for `epoch` from the current state"""
        return self.mcls._from_fs(self.ds, self.tr, epoch,
                                  self.func("cost"),
                                  self.func("ce"),
                                  self.func("acc"),
                                  self.probtype)

    def gdstep(self, report):
        """step in the direction of -gradient, adjust learning rate and store params

        :param bool report: log the latest learn-info report before stepping"""

        ## record the state before any training step
        if not self.learninfo:
            self.learninfo.append(self._monitor_record(0))
        if report:
            log.info(self.learninfo[-1].report)

        #step on the direction of gradient
        self.model.update_params(self.ds.train_batches, self.func("grad"),
                                 self.tr.momentum_mult, self.tr.lr)

        #register new params
        self.last_params.append(self.model.get_weights())

        #update_lr once enough history is available
        if len(self.learninfo) > self.tr.tau:
            new_lr = adjust_lr([r.validcost for r in self.learninfo], self.tr.lr)
            self.tr = self.tr._replace(lr=new_lr)

        ## add monitor information
        self.learninfo.append(self._monitor_record(self.epoch))


class CnnTrainer(BaseTrainer, GlobalSGDTrainer):
    """trainer that updates all layers of a convolutional model at once
    with SGD. Inputs get an extra singleton axis after the first one."""

    def __init__(self, ds, model, mtr, lmon, seed, prefix=""):
        """
        :param ds: the training dataset
        :param model: the model; the type of model[-1] selects the problem type
        :param mtr: the training spec
        :param lmon: monitor class used to build learn-info records
        :param seed: seed or None (see BaseTrainer)
        :param str prefix: prefix of the train name"""
        ## also sets min_epochs and max_epochs
        BaseTrainer.__init__(self, ds, mtr, model, seed, prefix)

        ## snapshots of the weights after each of the last patience + 1 steps
        self.last_params = deque([], maxlen=self.tr.patience + 1)
        self.mcls = lmon
        self.learninfo = []
        self._compiled_f = {}
        ## the problem type is the ProblemType.top_layer key whose value is
        ## the type of the model's top layer (a comprehension rather than
        ## filter(...)[0], which only works where filter returns a list)
        matches = [k for k in ProblemType.top_layer
                   if type(model[-1]) == ProblemType.top_layer[k]]
        self.probtype = ProblemType.parse(matches[0])

    def _monitor_record(self, epoch):
        """build a learn-info record for `epoch` from the current state"""
        return self.mcls._from_fs(self.ds, self.tr, epoch,
                                  self.func("cost"),
                                  self.func("ce"),
                                  self.func("acc"),
                                  self.probtype)

    def gdstep(self, report):
        """step in the direction of -gradient, adjust
        learning rate and store params

        :param bool report: log the latest learn-info report before stepping"""

        ## record the state before any training step
        if not self.learninfo:
            self.learninfo.append(self._monitor_record(0))
        if report:
            log.info(self.learninfo[-1].report)

        #step on the direction of gradient
        self.model.update_params(self.ds.train_batches, self.func("grad"),
                                 self.tr.momentum_mult, self.tr.lr)

        #register new params
        self.last_params.append(self.model.get_weights())

        #update_lr once enough history is available
        if len(self.learninfo) > self.tr.tau:
            new_lr = adjust_lr([r.validcost for r in self.learninfo], self.tr.lr)
            self.tr = self.tr._replace(lr=new_lr)

        ## add monitor information
        self.learninfo.append(self._monitor_record(self.epoch))

    def func(self, what):
        """return the compiled (and cached) theano function `what`

        :param str what: one of "ce", "acc", "cost", "grad"; each compiled
            function takes a batch index
        :raises ValueError: on an unknown function name"""

        if what not in ("ce", "acc", "cost", "grad"):
            raise ValueError("cannot compile function %s" % what)

        if what in self._compiled_f:
            return self._compiled_f[what]

        index = T.iscalar("batch_index")
        in_bs = self.ds.batch_size
        layers = self.model
        ## X is shared with a singleton axis inserted after the first one
        shX, shY = (self.ds.share("X", (self.ds.X.shape[0], 1) + self.ds.X.shape[1:]),
                    ProblemType.ds_shout(self.ds, self.probtype))
        ## feed the model one minibatch selected by `index`
        givens = {layers.X: shX[index * in_bs: (index + 1) * in_bs],
                  layers.Y: shY[index * in_bs: (index + 1) * in_bs]}
        tr = self.tr

        if what == "ce":
            ## cost without regularization
            f = theano.function(inputs=[index], outputs=layers.cost(0, 0),
                                givens=givens)
        elif what == "cost":
            f = theano.function(inputs=[index],
                                outputs=layers.cost(tr.l1_rate, tr.l2_rate),
                                givens=givens)
        elif what == "grad":
            ## all parameters of all layers, in layer order
            params = [p for layer in layers for p in layer.get_params()]
            f = theano.function([index],
                                outputs=T.grad(layers.cost(tr.l1_rate, tr.l2_rate),
                                               wrt=params),
                                givens=givens)
        elif what == "acc":
            if self.probtype == ProblemType.regression:
                f = self.func("cost")
            else:
                ## number of mispredicted examples in the batch
                f = theano.function(inputs=[index],
                                    outputs=T.sum(T.neq(layers[-1].y_hat, layers.Y)),
                                    givens=givens)
        return self._compiled_f.setdefault(what, f)


class UTrainer(BaseTrainer):
    """layer-wise trainer: layer `lidx` is trained on its own cost.
    learn-info records carry the `layer` they belong to."""

    def __init__(self, ds, model, mtr, lmon, seed, prefix=""):
        """
        :param ds: the training dataset (must expose shX, batch_size, train_batches)
        :param model: the model (indexable by layer)
        :param mtr: the training spec
        :param lmon: monitor class used to build learn-info records
        :param seed: seed or None (see BaseTrainer)
        :param str prefix: prefix of the train name"""
        BaseTrainer.__init__(self, ds, mtr, model, seed, prefix)

        self._compiled_f = {}
        ## snapshots of the whole model's weights after each of the last
        ## patience + 1 steps
        self.last_params = deque([], maxlen=self.tr.patience + 1)
        self.mcls = lmon
        self.learninfo = []
        ## lr restored before training any layer other than 0
        self.__orig_lr = mtr.lr

    def layer_learninfo(self, lidx):
        """list of the learn-info records of layer `lidx`"""
        return [r for r in self.learninfo if r.layer == lidx]

    def lepoch(self, lidx):
        """number of learn-info records of layer `lidx`"""
        return sum(1 for r in self.learninfo if r.layer == lidx)

    def func(self, what, lidx):
        """return the compiled (and cached) theano function `what` of layer `lidx`

        :param str what: one of "rec", "cost", "grad"
        :param int lidx: layer index
        :raises ValueError: on an unknown function name"""

        if what not in ("rec", "cost", "grad"):
            raise ValueError("cannot compile function %s" % what)

        fkey = "%s%d" % (what, lidx)
        if fkey in self._compiled_f:
            return self._compiled_f[fkey]

        index = T.iscalar("batch_index")
        layers = self.model
        bs = self.ds.batch_size
        ## feed the first layer one minibatch selected by `index`
        givens = {layers[0].input: self.ds.shX[index * bs: (index + 1) * bs]}
        tr = self.tr

        if what == "cost":
            f = theano.function(inputs=[index],
                                outputs=layers[lidx].cost(tr.l1_rate, tr.l2_rate),
                                givens=givens)
        elif what == "rec":
            ## cost without regularization
            f = theano.function(inputs=[index],
                                outputs=layers[lidx].cost(0, 0), givens=givens)
        elif what == "grad":
            f = theano.function(inputs=[index],
                                outputs=T.grad(layers[lidx].cost(tr.l1_rate, tr.l2_rate),
                                               wrt=layers[lidx].get_params()),
                                givens=givens)
        return self._compiled_f.setdefault(fkey, f)

    def gdstep(self, lidx, report):
        """step layer `lidx` in the direction of -gradient, adjust learning
        rate and store params

        :param int lidx: layer index
        :param bool report: log the latest learn-info report before stepping"""

        W = self.model[lidx].get_weights()[0]
        ## record the state of this layer before its first training step
        if not self.layer_learninfo(lidx):
            self.learninfo.append(self.mcls._from_fs(self.ds, self.tr, 0, 0, lidx,
                                                     self.func("cost", lidx),
                                                     W.shape[0]))
        if report:
            log.info(self.learninfo[-1].report)

        #step on the direction of gradient
        self.model.update_params(self.ds.train_batches, self.func("grad", lidx),
                                 self.tr.momentum_mult, self.tr.lr, lidx)

        #register new params
        self.last_params.append(self.model.get_weights())

        #update_lr: both the guard and the history are per layer, since the
        #learning rate schedule restarts for every layer (see train_layer)
        layer_info = self.layer_learninfo(lidx)
        if len(layer_info) > self.tr.tau:
            new_lr = adjust_lr([r.validcost for r in layer_info], self.tr.lr)
            self.tr = self.tr._replace(lr=new_lr)

        ## add monitor information
        self.learninfo.append(self.mcls._from_fs(self.ds, self.tr,
                                                 self.epoch, self.lepoch(lidx), lidx,
                                                 self.func("cost", lidx),
                                                 W.shape[0]))

    def is_up(self, lidx, what="validcost"):
        """has `what` of layer `lidx` gone up over the last patience epochs?
        (delegates to the monitor class' is_min_up)"""
        return self.mcls.is_min_up(what, self.tr.patience, self.layer_learninfo(lidx))

    def __fixed_epoch_train(self, lidx, nepochs, rfreq):
        """run `nepochs` gradient steps on layer `lidx`"""
        for i in range(nepochs):
            self.gdstep(lidx, i % rfreq == 0)

    def __early_stopping_train(self, lidx, rfreq):
        """train layer `lidx` for min_epochs, then until its validcost goes up
        or max_epochs is reached"""
        self.__fixed_epoch_train(lidx, min(self.min_epochs, self.max_epochs), rfreq)

        patience = self.tr.patience
        for epoch in range(self.min_epochs, self.max_epochs):
            if self.is_up(lidx):
                log.info("validcost is up. Done training")
                ## restoring requires a full history of patience + 1 snapshots
                assert len(self.last_params) == patience + 1, \
                    ("len(last_params) is %d expected %d" % (len(self.last_params),
                                                             patience + 1))
                self.restore_last_params(lidx)
                break
            self.gdstep(lidx, epoch % rfreq == 0)
        else:
            ## the loop was never broken: the whole epoch budget was used
            log.info("Hit maxepoch (%d). Done training layer %d!",
                     self.max_epochs, lidx)

    def train_layer(self, lidx, strategy, rfreq):
        """train layer `lidx`

        :param str strategy: "min_epochs", "max_epochs" or "estop"
        :param int rfreq: report every `rfreq` epochs
        :raises ValueError: on an unknown strategy"""
        ## every layer after the first starts again from the original lr
        if lidx != 0:
            self.tr = self.tr._replace(lr=self.__orig_lr)

        if strategy in ("min_epochs", "max_epochs"):
            self.__fixed_epoch_train(lidx, getattr(self, strategy), rfreq)
        elif strategy == "estop":
            self.__early_stopping_train(lidx, rfreq)
        else:
            raise ValueError("un-supported learning strategy: %s" % strategy)

    def restore_last_params(self, i, learnlog_lst=None):
        """last patience epochs were bad, so restore
        the parameters of the model from the first element of last_params

        :param int i: index of the layer being trained
        :param learnlog_lst: unused"""

        ## restore the best parameters
        ## (from patience-1 i.e., the first element of last_params)
        self.model.weights = self.last_params[0]
        ## layer epoch j is the (j+1)-th record of the layer; the kept state
        ## is the one patience records before the last
        log.info("Restored best model from epoch %d",
                 (self.lepoch(i) - self.tr.patience - 1))
        ## drop the learn info of the discarded epochs as well
        self.learninfo = self.learninfo[:-self.tr.patience]


class Experiment(object):
    """an experiment class that provides the utilities for loading and saving
    information

    It is possible to create new experiments or load one from previous runs"""

    def __init__(self, ds_path, ptp, trid):
        """instantiate an experiment

        :param str ds_path: path to a dataset (archive:dataset)
        :param ptp: problem type. see :class:`dimer.nnet.ProblemType`
        :param str trid: train id, possibly of a previous experiment.
            see :func:`Experiment.train_name`
        """

        self.__name_of_this_run = trid
        self.dspath = ds_path
        self.probtype = ProblemType.parse(ptp)

    def load_mspec(self):
        """load the CnnModelSpec saved under this train id"""
        arch, dsname = archive.split(self.dspath)
        return CnnModelSpec._from_archive(arch, os.path.join(dsname, str(self)))

    @staticmethod
    def train_name(prefix, seed="noseed"):
        """an ID assigned to an experiment. this will correspond to a
        train subdir inside archive named
        **tr_<conf_file_name>_<timestamp>_<seed>**"""

        cfg_name = os.path.splitext(os.path.basename(prefix))[0]
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        seed = str(seed)
        return "tr_%s_%s_%s" % (cfg_name, timestamp, seed)

    @staticmethod
    def parse_train_name(trn, P=re.compile(r"tr_(.+)_(\d{8}_\d{6})_(\d+)")):
        """split a train name into (cfg_name, timestamp, seed) strings

        :raises ValueError: if `trn` does not look like a train name"""
        match = P.match(trn)
        if match is None:
            raise ValueError("not a valid train name: %r" % (trn,))
        return match.group(1, 2, 3)

    @classmethod
    def _from_path(cls, ds_path, trainid, ptp):
        """re-create the experiment of a previous run from its train id"""
        ## validate the train id (raises ValueError if malformed)
        cls.parse_train_name(trainid)
        ## __init__ only takes (ds_path, ptp, trid); the previous call passed
        ## cfg_name and seed as well and always failed with TypeError
        return Experiment(ds_path, ptp, trainid)

    def __str__(self):
        return self.__name_of_this_run


class TrainExperiment(Experiment):
    """a class that automates training of a cnn model"""

    def __init__(self, ds_path, cfg_file, ptp, seed, vsize, vidx):
        """instantiate this class

        :param ds_path: path to a dataset. usual format file:path_of_ds
        :param cfg_file: config file
        :param ptp: problem type (instantiate from nnet.ProblemType)
        :param seed: a number or None
        :param vsize: size of validation set
        :param vidx: index (in terms of batches) of the validation set"""

        ## initialize random stream
        if seed is None:
            seed = np.random.randint(0, 100000)
        log.info("SEED is : %s", str(seed))
        self.rng = np.random.RandomState(seed)

        Experiment.__init__(self, ds_path, ptp,
                            self.train_name(cfg_file, seed))

        self.tr = MtrainSpec._from_settings(cfg_file)
        self.ds = data.TrainAnchorDataset._from_archive(ds_path, False,
                                                        self.tr.batch_size,
                                                        valid_s=vsize,
                                                        valid_idx=vidx,
                                                        rng=self.rng)

        ms = CnnModelSpec._from_settings(cfg_file)
        self.model = CnnModel((ms.nkerns, ms.rfield, ms.pool), ms.lreg_size,
                              (self.tr.batch_size, 1, self.ds.tracks, self.ds.width),
                              ProblemType.ds_nout(self.ds, ptp), self.rng,
                              xdtype=self.ds.X.dtype,
                              ydtype=ProblemType.ds_out(self.ds, ptp).dtype)
        self.ms = ms

        ## must exist before self.func is called below
        self._compiled_f = {}
        ## snapshots of the weights after each of the last patience + 1 steps
        self.last_params = deque([], maxlen=self.tr.patience + 1)

        ## learn info; entry i describes the model after i gradient steps
        self.linfo = [LearnMonitor._from_fs(self.ds, self.tr, 0,
                                            self.func("cost"),
                                            self.func("ce"),
                                            self.func("acc"),
                                            self.probtype)]

        self.min_epochs = self.tr.minepochs + self.tr.patience
        self.max_epochs = self.tr.nepochs

    @property
    def epoch(self):
        """number of learn-info records so far"""
        return len(self.linfo)

    def is_up(self, what="validcost"):
        """has `what` gone up over the last patience epochs?
        (delegates to LearnMonitor.is_min_up)"""
        return LearnMonitor.is_min_up(what, self.tr.patience, self.linfo)

    def restore_last_params(self):
        """last patience epochs were bad, so restore
        the parameters of the model from the first element of last_params"""

        ## restore the best parameters
        ## (from patience-1 i.e., the first element of last_params)
        self.model.weights = self.last_params[0]
        ## linfo[i] is epoch i; the kept state is the entry patience records
        ## before the last, i.e. epoch len(linfo) - patience - 1
        log.info("Restored best model from epoch %d",
                 (self.epoch - self.tr.patience - 1))
        ## drop the learn info of the discarded epochs as well
        self.linfo = self.linfo[:-self.tr.patience]

    def gdstep(self, report):
        """step in the direction of -gradient, adjust learning rate and store params

        :param bool report: log the latest learn-info report before stepping"""

        if report:
            log.info(self.linfo[-1].report)

        #step on the direction of gradient
        self.model.update_params(self.ds.train_batches, self.func("grad"),
                                 self.tr.momentum_mult, self.tr.lr)
        #register new params
        self.last_params.append(self.model.get_weights())

        #update_lr once enough history is available
        if len(self.linfo) > self.tr.tau:
            new_lr = adjust_lr([r.validcost for r in self.linfo], self.tr.lr)
            self.tr = self.tr._replace(lr=new_lr)

        #save current learninfo
        self.linfo.append(LearnMonitor._from_fs(self.ds, self.tr, self.epoch,
                          self.func("cost"), self.func("ce"), self.func("acc"),
                          self.probtype))

    def share_dataset(self):
        "share dataset X (with a singleton axis inserted after the first) and Y"

        return (self.ds.share("X", (-1, 1, self.ds.X.shape[1], self.ds.X.shape[2])),
                ProblemType.ds_shout(self.ds, self.probtype))

    def func(self, what):
        """return the compiled (and cached) theano function `what`

        :param str what: one of "ce", "acc", "cost", "grad"; each compiled
            function takes a batch index
        :raises ValueError: on an unknown function name"""

        if what not in ("ce", "acc", "cost", "grad"):
            raise ValueError("cannot compile function %s" % what)

        if what in self._compiled_f:
            return self._compiled_f[what]

        index = T.iscalar("batch_index")
        in_bs = self.ds.batch_size
        layers = self.model
        dt_x, dt_y = self.share_dataset()
        ## feed the model one minibatch selected by `index`
        givens = {layers.X: dt_x[index * in_bs: (index + 1) * in_bs],
                  layers.Y: dt_y[index * in_bs: (index + 1) * in_bs]}
        tr = self.tr

        if what == "ce":
            ## cost without regularization
            f = theano.function(inputs=[index], outputs=layers.cost(0, 0),
                                givens=givens)
        elif what == "cost":
            f = theano.function(inputs=[index],
                                outputs=layers.cost(tr.l1_rate, tr.l2_rate),
                                givens=givens)
        elif what == "grad":
            ## all parameters of all layers, in layer order
            params = [p for layer in layers for p in layer.get_params()]
            f = theano.function([index],
                                outputs=T.grad(layers.cost(tr.l1_rate, tr.l2_rate),
                                               wrt=params),
                                givens=givens)
        elif what == "acc":
            if self.probtype == ProblemType.regression:
                f = self.func("cost")
            else:
                ## number of mispredicted examples in the batch
                f = theano.function(inputs=[index],
                                    outputs=T.sum(T.neq(layers[-1].y_hat, layers.Y)),
                                    givens=givens)
        return self._compiled_f.setdefault(what, f)

    def save(self):
        "save model and model spec under this train id"

        arch, dsn = archive.split(self.dspath)
        self.model.save(os.path.join(self.dspath, str(self)))
        self.ms.archive_dump(arch, os.path.join(dsn, str(self)))
