"""classes to handle configuration files"""

import os
import logging
import pickle
from ConfigParser import SafeConfigParser
from collections import namedtuple

import pandas as pd

from .. import archive

# module-level logger, named after this module's import path
log = logging.getLogger(__name__)


def _open_cfg(path):
    cfg = SafeConfigParser()
    if not len(cfg.read([path])):
        raise ValueError("cannot load %s" % path)
    return cfg


def _check_exist_f(s):
    if not os.path.isfile(s):
        import warnings
        warnings.warn("non-existent input file '%s'" % s)
    return s


def _check_get_pairs(s, tp=int):
    "a tuple of pairs"
    items = map(tp, s.split())
    p = []
    npairs = len(items) / 2
    for i in range(npairs):
        p.append( (items[i * 2], items[i * 2 + 1]) )
    return tuple(p)


def _check_get_singles(s, tp=int):
    "a tuple"
    return tuple( map(tp, s.split()) )


class CfgFactory(object):
    """abstract class with factory method from a config file

    Meant to be mixed into a namedtuple subclass, which supplies `_fields`.

    to subclass, need to define:

      * _types: one converter per field, in field order; each is called
        with the raw option string read from the config file
      * _section: section on the config file (also used as the leaf name
        when saving to / loading from an archive)"""

    @classmethod
    def _from_settings(cls, path, new=tuple.__new__, len=len):
        """make a new object from a settings file. the _check_consistency
        method is called just before returning the namedtuple

        Every name in `cls._fields` is read as an option of section
        `cls._section` and converted with the matching entry of `cls._types`.

        :param path: path of the config file
        :param new: allocator used to build the instance
        :param len: length function (part of the existing signature; it
            shadows the builtin inside this method)
        :raises ValueError: if `_fields` and `_types` differ in length, or if
            the config file cannot be loaded
        :raises TypeError: if the built object does not have one value per
            field
        """

        nfields = len(cls._fields)
        ntypes = len(cls._types)
        if nfields != ntypes:
            # argument order follows the message: types first, then attributes
            raise ValueError("(%d attr. types) vs. (%d attributes) for %s" %
                             (ntypes, nfields, str(cls)))

        cfg = _open_cfg(path)
        # explicit unpacking in a comprehension instead of a tuple-parameter
        # lambda, which is not valid syntax on Python 3
        values = [convert(cfg.get(cls._section, field))
                  for field, convert in zip(cls._fields, cls._types)]
        result = new(cls, values)
        if len(result) != len(cls._fields):
            raise TypeError('expected %d arguments, got %d' % (len(cls._fields),
                            len(result)))
        result._check_consistency()
        return result

    def _check_consistency(self):
        "check as much as you can that values of params make sense"

        # no-op by default; subclasses override and raise on bad values
        pass

    def archive_dump(self, arch, trpath):
        """dump self into archive

        :param arch: archive handle/path, passed through to the `archive`
            module
        :param trpath: location inside the archive; the object is stored
            under `trpath`/`_section`
        """
        log.info("saving to %s", archive.join(arch, os.path.join(trpath, self._section)))
        # the pickle (highest protocol) is stored as a Series of its elements
        # NOTE(review): list(pickle.dumps(...)) yields 1-char strings only on
        # Python 2; _from_archive relies on that -- confirm before porting
        arch_repr = pd.Series(list(pickle.dumps(self, -1)), name=str(self))
        archive.save_object(arch, os.path.join(trpath, self._section), arch_repr)

    @classmethod
    def _from_archive(cls, arch, trpath):
        """load from archive

        :param arch: archive handle/path, passed through to the `archive`
            module
        :param trpath: location inside the archive; the object is read from
            `trpath`/`_section`
        :rtype: the unpickled object
        """
        arch_repr = archive.load_object(arch, os.path.join(trpath,
                                                           cls._section))
        # SECURITY: unpickling can execute arbitrary code; only load archives
        # that come from a trusted source
        return pickle.loads( "".join(arch_repr.values) )


class MlpModelSpec(namedtuple('MlpMetaParams', ('hsizes',)), CfgFactory):
    """Mlp model specification class

    `hsizes` is read from the "mlpmodel" section as a whitespace-separated
    list of ints."""

    __slots__ = ()

    _types = (_check_get_singles,)
    _section = "mlpmodel"

    def _check_consistency(self):
        """check that every hidden layer size is positive

        :raises ValueError: if any entry of `hsizes` is <= 0
        """
        # iterate over the actual sizes; the previous map() call was given
        # no iterable and raised TypeError on every call
        if any(size <= 0 for size in self.hsizes):
            raise ValueError("size of hidden layers must be positive")

class CnnModelSpec(namedtuple('CnnMetaParams',
                              ("nkerns", "rfield", "pool", "lreg_size")),
                   CfgFactory):
    """Cnn model specification class

    Fields are read from the "cnnmodel" section; `_help` holds one short
    description per field, in field order."""

    __slots__ = ()

    _section = "cnnmodel"
    _types = (_check_get_singles, _check_get_pairs, _check_get_pairs, int)
    _help = ("number of kernels",
             "2d receptive field dimensions",
             "2d pooling dimensions",
             "fully connected layer output dimension")

    def _check_consistency(self):
        """check that nkerns, pool and rfield all have the same length

        :raises ValueError: if the lengths differ
        """
        nlayers = len(self.nkerns)
        lengths_agree = len(self.pool) == nlayers and len(self.rfield) == nlayers
        if not lengths_agree:
            raise ValueError(" len(self.nkerns) !=  len(self.pool) or len(self.nkerns) != len(self.rfield)")

    @property
    def cp_arch(self):
        ":rtype: the (nkerns, rfield, pool) triple"
        return self.nkerns, self.rfield, self.pool


class AESpec(namedtuple("AEMetaParams", ("ins", "hiddens", "clevel")),
             CfgFactory):
    """autoencoder specification class

    Fields are read from the "aemodel" section."""

    __slots__ = ()

    _section = "aemodel"
    _types = (int, _check_get_singles, float)

    def _check_consistency(self):
        "nothing is checked for this spec"


class MtrainSpec( namedtuple("MtrainSpec", "batch_size l1_rate l2_rate lr tau momentum_mult nepochs minepochs patience"), CfgFactory ):
    """Cnn model train specification class

    Fields are read from the "modtrain" section."""
    __slots__ = ()

    _types = (int, float, float, float, int, float, int, int, int)
    _section = "modtrain"

    def _check_consistency(self):
        """check that minepochs does not exceed nepochs

        :raises ValueError: if minepochs > nepochs
        """
        if self.minepochs > self.nepochs:
            # exceptions do not interpolate logging-style arguments, so the
            # message has to be formatted explicitly
            raise ValueError("minepochs (%d) > nepochs (%d)" %
                             (self.minepochs, self.nepochs))


class ExpSpec(namedtuple("experiment_spec", "arch dsn trid"), CfgFactory):
    """keeps metadata about a training experiment

    Fields are read from the "train_exp" section. Inherits CfgFactory, like
    the other spec classes, so that `_types`, `_section` and
    `_check_consistency` are actually used by `_from_settings`."""

    __slots__ = ()
    _types = (_check_exist_f, str, str)
    _section = "train_exp"

    def _check_consistency(self):
        "evaluate `seed`, so that any error it raises for this trid shows up early"
        self.seed

    @property
    def seed(self):
        ":rtype: the seed used for this training experiment"

        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import -- confirm
        from dimer.nnet.train import BaseTrainer
        return BaseTrainer.seed_from_train_name(self.trid)

    @property
    def ds_path(self):
        ":rtype: path to the dataset (in the form <archive_file_path>:<dsname>)"

        return archive.join(self.arch, self.dsn)

    @property
    def trid_path(self):
        ":rtype: path to the training run, i.e. the dataset path joined with the trid (<ds_path>/<trid>)"

        return os.path.join(self.ds_path, self.trid)

    @property
    def abs_trid(self):
        ":rtype: trid qualified by the dataset name, without the archive file (<dsname>/<trid>)"
        return os.path.join(self.dsn, self.trid)
