"""module for handling generation and formating of input representations
present at an arbitrary layer of a deep model"""


import os
from collections import namedtuple
from operator import itemgetter

import numpy as np
import theano
from bx.bbi.bigwig_file import BigWigFile
from dimer.ops import binned_signal, inverse_hsine
from dimer import human_readable
import logging

lg = logging.getLogger(__name__)

class RawFrep(namedtuple("rfr", "chrom start end metalab")):
    """utilities for managing the representation of a genome region

    a record is the tuple (chrom, start, end, metalab), where metalab is a
    string of unit values (normally floats) separated by RawFrep._mtlb_sep"""

    # separator between the unit values inside the metalab field
    _mtlb_sep = " "
    # separator between the four fields of a line in a .raw file
    _sep = "\t"
    # format applied to each unit value when building a metalab from numbers
    _mtlb_frm = "%.1f"

    def argmax(self, n=1):
        """return the index of the label that takes the highest value.

        if n > 1 return the indexes of the n labels with the highest values.
        indexes are ordered by increasing label value (the index of the
        highest label comes last)

        :param int n: number of indexes to return
        :rtype: tuple of size n
        :raises ValueError: if a label cannot be parsed as a float"""

        # values must be compared numerically: sorting the raw strings
        # would rank e.g. "10.0" below "9.0"
        idx = np.argsort(self.arr_metalab)[-n:]
        return tuple(idx.tolist())

    @property
    def nunits(self):
        ":rtype: number of units, i.e., size of the representation"

        return len(self.tup_metalab)

    @property
    def tup_metalab(self):
        ":rtype: list of the unit values, as strings"

        return self.metalab.split(self._mtlb_sep)

    @property
    def arr_metalab(self):
        ":rtype: 1-d ndarray with the unit values parsed as floats"

        # build a real list first: np.array() over a lazy map object does
        # not yield a numeric array
        return np.array([float(v) for v in self.tup_metalab])

    def to_raw(self):
        """get a string representation that conforms to the `Raw` format

        :rtype: str of the form chrom<TAB>start<TAB>end<TAB>metalab"""

        ofrm = self._sep.join(("%s", "%d", "%d", "%s"))
        return ofrm % (self.chrom, self.start, self.end, self.metalab)

    @classmethod
    def nunits_infile(cls, fpath):
        """number of units of the record on the first line of a .raw file

        :param str fpath: path to the file
        :rtype: int"""

        with open(fpath) as fd:
            return cls.from_line(fd.readline()).nunits

    @classmethod
    def from_array(cls, chrom, start, end, x, error=None):
        """create from a genome interval and a ndarray representation --
        0-based half open. accepts error for this repr. if computed

        :param str chrom: chromosome
        :param int start: start position
        :param int end: end position
        :param iterable x: metalabel values (e.g., an ndarray)
        :param float error: reconstruction error of this representation.
                if given, it is appended as the last value of the metalabel
        :rtype: RawFrep"""

        values = [cls._mtlb_frm % float(v) for v in x]
        if error is not None:
            values.append(cls._mtlb_frm % error)
        return cls._make((chrom, int(start), int(end),
                          cls._mtlb_sep.join(values)))

    @classmethod
    def from_line(cls, line):
        """parse the row of a .raw file

        this type of file is in BED4 format, but the name field
        is a space-separated list of unit values (str). the i-th float is the
        i-th coordinate of the representation

        :param str line: the line to parse
        :rtype: RawFrep
        :raises ValueError: if the line does not have exactly 4 fields or
                start / end are not integers"""

        ch, st, en, metalab = line.rstrip().split(cls._sep)
        return cls._make((ch, int(st), int(en), metalab.rstrip()))

    @classmethod
    def from_line_ec(cls, line):
        """parse the row of a .raw file and perform metalab-format check

        this type of file is in BED4 format, but the name field
        is a space-separated list of unit values (str). the i-th float is the
        i-th coordinate of the representation.

        the check rejects metalabels that contain empty values, i.e.,
        leading or repeated separators

        :param str line: the line to parse
        :rtype: RawFrep
        :raises ValueError: if the line does not have exactly 4 fields,
                start / end are not integers or the metalabel is malformed"""

        ch, st, en, raw_metalab = line.rstrip().split(cls._sep)
        # re-joining *all* the pieces always reproduces the input, so the
        # empty pieces are dropped to make the comparison meaningful
        pieces = raw_metalab.split(cls._mtlb_sep)
        metalab = cls._mtlb_sep.join(p for p in pieces if p)
        if metalab != raw_metalab:
            raise ValueError("line not in RawFrep format (%s != %s)" %
                             (metalab, raw_metalab))
        return cls._make((ch, int(st), int(en), metalab))

    def binarized(self, thr, nan_symb, shift_down=True):
        """return a binarized representation of this

        labels <= thr will be set to 0, all others to 1. if any label is
        nan, all labels are set to nan_symb instead

        :param thr: threshold. either a number, "auto" (midpoint between
                the lowest and the highest label, or 0.0 if all labels are
                equal) or "max" (just below the highest label, but not
                less than 0)
        :param str nan_symb: substitute nans with this symbol
        :param bool shift_down: not used by this method
        :rtype: RawFrep instance with binarized metalabel
        :raises ValueError: if thr is a string other than "auto" or "max"
        """

        aml = self.arr_metalab
        if np.any(np.isnan(aml)):
            out_label = [nan_symb] * aml.shape[0]
            return self._replace(metalab=self._mtlb_sep.join(out_label))

        if thr == "auto":
            lo, hi = np.min(aml), np.max(aml)
            thr = 0.0 if lo == hi else float(0.5 * (lo + hi))
        elif thr == "max":
            thr = max(np.max(aml) - 0.0000001, 0)
        if isinstance(thr, str):
            raise ValueError("unsupported thr(%s) type (%s)" % (str(thr),
                                                                str(type(thr))))
        new_ml = ['0' if v <= thr else '1' for v in aml]
        return self._replace(metalab=self._mtlb_sep.join(new_ml))

    def digitized(self, nbins, nan_symb, shift_down=True):
        """return a digitized representation of this

        the bin edges are 0, s, 2s, ... with s = (10 // nbins) / 10.0
        (nbins edges in total); every label is replaced by the index of
        the bin it falls in. if any label is nan, all labels are set to
        nan_symb instead

        :param int nbins: nr. of bins (between 1 and 10)
        :param str nan_symb: substitute nans with this symbol
        :param bool shift_down: labels start counting from 0
        :rtype: RawFrep instance with digitized metalabel
        :raises ValueError: if nbins > 10 (the bin width would be 0)"""

        # floor division: range() needs an integer step
        step = 10 // nbins
        bins = np.array(range(0, 10 - 10 % nbins, step)) / 10.0
        lg.debug("bins: %s", str(bins.tolist() + ["inf"]))

        x = self.arr_metalab
        if np.any(np.isnan(x)):
            out_label = [nan_symb] * x.shape[0]
        else:
            # np.digitize counts bins from 1
            offset = 1 if shift_down else 0
            out_label = [str(b) for b in np.digitize(x, bins) - offset]
        return self._replace(metalab=self._mtlb_sep.join(out_label))

    def keep_units(self, idx):
        """create an instance like this one, with all
        but the selected labels removed

        :param iterable idx: 0-based indexes of the labels to keep, in the
                order they should appear. if empty, all labels are kept
        :rtype: RawFrep
        :raises IndexError: if an index is out of range"""

        idx = list(idx)
        units = self.tup_metalab
        kept = [units[i] for i in idx] if idx else units
        return self._replace(metalab=self._mtlb_sep.join(kept))


class BwTracks(object):
    """represents multiple bigWig signal tracks"""

    def __init__(self, fpath_list, bin_size, bin_f, dry_run):
        """instantiate from a set of paths to bigWig files

        :param list fpath_list: paths to bigwig files (must support len())
        :param int bin_size: bin the bigwig signal
        :param callable bin_f: function that summarizes the binned signal
                (into a number)
        :param bool dry_run: simulate function calls on fake signal"""

        self.fpaths = fpath_list
        self.ntracks = len(fpath_list)
        self.bin_size = bin_size
        self.bin_f = bin_f
        self.dr = dry_run

    def __iter_bw_hnd(self):
        """yield bigwig handlers, one per track

        the file behind a handler is closed as soon as the next handler is
        requested (or the iteration ends), so a handler must be consumed
        before advancing the iterator"""

        for fpath in self.fpaths:
            lg.debug("opening %s ...", os.path.basename(fpath))
            # bigWig is a binary format; the context manager makes sure
            # that file descriptors are not leaked
            with open(fpath, "rb") as fd:
                yield BigWigFile(fd)

    def __sigld(self, hnd, chrom, start, end):
        """load signal with the given coordinate from the given handle

        on a dry run the handle is not read and the signal is all 0s

        :param bx.bbi.bigwig_file.BigWigFile hnd: bigwig handle
        :param str chrom: chromosome
        :param int start: start
        :param int end: end
        :rtype: ndarray of the signal"""

        if self.dr:
            return np.zeros((end - start),)
        return inverse_hsine(hnd.get_as_array(chrom, start, end))

    def load_region(self, chrom, start, end):
        """load signal with the given coordinate from all tracks

        the signal of each track is binned with bin_size / bin_f

        :param str chrom: chromosome
        :param int start: start
        :param int end: end
        :rtype: ndarray with one row / track and one column / bin"""

        xlst = []
        for hnd in self.__iter_bw_hnd():
            xlst.append(binned_signal(self.__sigld(hnd, chrom, start, end),
                                      self.bin_size, self.bin_f))
        X = np.array(xlst)
        # floor division: the nr. of bins is an integer
        assert X.shape == (self.ntracks, (end - start) // self.bin_size)

        lg.debug("%d, loaded signal %s for %s: [%s-%s]", os.getpid(),
                 str(X.shape), chrom,
                 human_readable(start), human_readable(end))
        return X


class TrackRepr(object):
    """class that computes representations of input tracks from a dAE"""

    def __init__(self, model, ds, batch_size, in_tracks, dry_run):
        """instantiate this class

        compiles, for every layer of the model, a function that maps an
        input to the hidden state of the layer and a function that maps a
        hidden state back to the reconstructed input of the layer

        :param dimer.nnet.autoencoder.AEStack model: model
        :param dimer.data.Dataseti ds: dataset instance (only the dtype of
                ds.X is used)
        :param int batch_size: compute model output in batches of this size
        :param BwTracks in_tracks: BwTrack instance
        :param bool dry_run: dry run"""

        self.model = model
        self.batch_size = batch_size
        self.intr = in_tracks
        self.dr = dry_run

        T = theano.tensor
        sigmoid = theano.tensor.nnet.sigmoid
        h_var = T.matrix("hvar", dtype=ds.X.dtype)

        def compiled(expr):
            return theano.function(inputs=[h_var], outputs=expr)

        # every layer is expected to return (weights, bias, bias_prime)
        params = [l.get_params() for l in model]
        # explicit lists, since both are indexed by layer later on
        self.h_f = [compiled(sigmoid(T.dot(h_var, w) + b))
                    for (w, b, _bp) in params]
        self.h_primef = [compiled(sigmoid(T.dot(h_var, w.T) + bp))
                         for (w, _b, bp) in params]

    def reconstruction_error(self, X0):
        """compute the reconstruction error of the model

        the input is encoded through all the layers (first to last) and
        the top hidden state is decoded back (last to first)

        :param X0: input, one example / row (ndarray or list of 1-d arrays)
        :rtype: 1-d ndarray with the mean absolute difference between each
                row of the input and its reconstruction
        """

        h_f, h_primef = self.h_f, self.h_primef
        nlayers = len(self.model)

        # encode: the hidden state of a layer is the input of the next
        state = X0
        for i in range(nlayers):
            state = h_f[i](state)

        # decode: walk the layers back down to the input space
        for i in reversed(range(nlayers)):
            state = h_primef[i](state)

        return np.mean(np.abs(np.array(X0) - state), axis=1)

    def repr_iter(self, X, layer, stride=1, errors=False):
        """compute the representation along the signal X in valid mode

        on a dry run arrays of 0s are produced instead of true
        representations, and the error is set to 0

        :param ndarray X: input signal of shape (tracks, genome_slice_size)
        :param int layer: representation is the hidden state of this layer
        :param int stride: only stride=1 is supported
        :param bool errors: whether to compute errors. if not error is set to 0
        :rtype: iterator of tuples of the type (ndarray, ndarray) =
                (repr, reconstr. error), one tuple / batch with one
                row / position in the batch
        """

        assert stride == 1, "stride=1 supported only (was %d)" % stride
        assert len(X.shape) == 2, str(X.shape)
        model = self.model
        W = model[0].get_weights()[0]
        Wo = model[layer].get_weights()[0]
        assert W.shape[0] % X.shape[0] == 0

        ## each element of xlst is a window of track_width bins taken
        ## across all tracks and flattened into a model input.
        ## floor division: track_width is used as a slice bound
        track_width = W.shape[0] // X.shape[0]
        xlst = [X[:, i:(i + track_width)].flatten()
                for i in range(X.shape[1] - track_width)]
        lg.debug("xlst: %d", len(xlst))

        for bstart in range(0, len(xlst), self.batch_size):
            batch = xlst[bstart: min(bstart + self.batch_size, len(xlst))]
            if self.dr:
                x = np.zeros((len(batch), Wo.shape[1]))
            else:
                x = model.compute_state(layer, batch)
            if errors and (not self.dr):
                e = self.reconstruction_error(batch)
            else:
                e = np.zeros((x.shape[0],))
            assert e.shape[0] == x.shape[0]
            yield (x, e)

    def repr_of_region(self, layer, chrom, start, end, errors=False):
        """compute the representation of given genome region

        the signal is loaded up to one input window past `end`, so that
        one representation is produced for each bin in [start, end)

        :param int layer: representation is the hidden state of this layer
        :param str chrom: chromosome
        :param int start: start position
        :param int end: end position
        :param bool errors: compute reconstruction error? if not, the
                yielded error is 0
        :rtype: iterator of tuples of the type
                (chrom, bin start, bin end, ndarray, error)
        """

        W = self.model[0].get_weights()[0]
        bin_size = self.intr.bin_size
        # floor division: genome coordinates and counts are integers
        extend = bin_size * W.shape[0] // self.intr.ntracks
        signal = self.intr.load_region(chrom, start, end + extend)
        repr_size = self.model[layer].get_weights()[0].shape[1]
        expected = (end - start) // bin_size

        c = 0
        for i, (x, e) in enumerate(self.repr_iter(signal, layer, errors=errors)):
            assert x.shape[1] == repr_size, "%d != %d" % (x.shape[1],
                                                          repr_size)
            batch_start = start + (i * self.batch_size * bin_size)
            lg.debug("batch %d of size %d", i, x.shape[0])
            for j in range(x.shape[0]):
                cur_start = batch_start + (j * bin_size)
                yield (chrom, cur_start, cur_start + bin_size, x[j], e[j])
                c += 1
        assert c == expected, ("%d!=%d" % (c, expected))

#    @classmethod
#    def repr_to_str(cls, x, sep, frm):
#        return sep.join(map(lambda v: frm % float(v), x))
#
#    def region_to_raw(self, (chrom, start, end, x), sep=" ", frm="%.1f"):
#        ofrm = "\t".join(("%s" % chrom, "%d", "%d", "%s"))
#        x_str = self.repr_to_str(x, sep, frm)
#        return ofrm % (start, start + self.intr.bin_size, x_str)
