#!/usr/bin/env python
"""Combine signal stored in UCSC's big binary indexed format (bigWig) into
a pandas Panel of shape

(<nr. of anchors> X <nr. of tracks> X <nr. of bins>)

built from the signal overlapping each anchor region.
The data is dumped into an HDF5 archive with the layout

/<hdf_path>X   wide    (shape->[<nr. of anchors>,<nr. of tracks>,<nr. of bins>])
/<hdf_path>Y   series  (shape->[<nr. of anchors>])
"""


import os
import sys
import logging
import argparse
import itertools
from operator import itemgetter
from functools import partial
from multiprocessing import Pool


import numpy as np
import pandas as pd

from bx.bbi.bigwig_file import BigWigFile

from dimer.argutils import ArgTypes
from dimer.genome import bedops, bed
from dimer import archive, data, ops

# Report everything: the pipeline below logs progress at DEBUG and INFO.
logging.basicConfig(level=logging.DEBUG)
lg = logging.getLogger()  # root logger, used by every step of this script

# All the work below happens at module level, so importing this file would
# run the whole pipeline; refuse to be imported.
if not __name__ == '__main__':
    lg.error("this is a script do not import")
    raise SystemExit(1)

# Command line interface: `anchors` are the regions signal is extracted
# over, each `input` file is one signal track.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    epilog="Taylor Lab (odenas@emory.edu)")
parser.add_argument("anchors", type=ArgTypes.is_file,
                    help="Gene regions (>= BED6)")
parser.add_argument("input", nargs="+", type=ArgTypes.is_file,
                    help="Input features (bigWig or bw)")
parser.add_argument("output", type=ArgTypes.is_dspath,
                    help="Output file. " + archive.DSPEC_MSG)
parser.add_argument("--bin_size", type=ArgTypes.is_pint, default=10, help="bin size")
parser.add_argument("--bin_op", type=str, choices=("sum", "min", "max", "mean"),
                    default="sum", help="bin the signal using this function")
parser.add_argument("--par", type=ArgTypes.is_pint, default=1, help="parallelize")
parser.add_argument("--fit", action='store_true', default=False,
                    help="For each track. Fit signal in [0, 1]")
parser.add_argument("--smooth", action='store_true', default=False,
                    help=("Smooth with an inverse hyperbolic sine: "
                          "ln(x + sqrt(x^2 + 1))"))
# The scores come from the 5th column of the *anchors* file (the inputs are
# bigWig tracks), so the BED5 requirement applies to the anchors.
parser.add_argument("--feature_score", action='store_true', default=False,
                    help=("Record scores of anchors as Y. In "
                          "this case anchors must be >= BED5"))
opt = parser.parse_args()

# --feature_score reads the 5th (score) column of the anchors, so make sure
# the first anchor line has at least five whitespace-separated fields.
with open(opt.anchors) as anchors_fd:
    if opt.feature_score:
        n_fields = len(anchors_fd.readline().split())
        if n_fields < 5:
            parser.error("feature_score=True, but anchors is < BED5")

def make_overlaps(signal_f, o=opt):
    """Extract the binned signal of one bigWig file over every anchor.

    Each line of the anchors file ``o.anchors`` is parsed with
    ``bedops.parseBED``. Its first three fields (chrom, start, end) are used
    to fetch the per-base signal from ``signal_f``. That signal is cut into
    consecutive bins of ``o.bin_size`` values and each bin is aggregated with
    the numpy function named by ``o.bin_op``.

    :param signal_f: path to a bigWig file
    :param o: parsed command line options (defaults to this script's ``opt``)
    :return: numpy array with one row per anchor and one column per bin.
        It is a proper 2D (anchors X bins) array only when all anchors have
        the same width.
    :raises ValueError: if an anchor width is not a multiple of ``o.bin_size``
    """

    anchorf = o.anchors
    bin_size = o.bin_size
    bin_op = getattr(np, o.bin_op)

    def binned_signal(a):
        # (width,) -> (width / bin_size, bin_size) -> one value per bin
        return bin_op(a.reshape((-1, bin_size)), axis=1)

    parse_bed = partial(bedops.parseBED, use_score=False)
    Xlst = []
    n_missing = 0
    # bigWig is a binary format; both handles are closed when done
    with open(signal_f, 'rb') as sig_fd, open(anchorf) as anchor_fd:
        signal_hnd = BigWigFile(sig_fd)
        for line in anchor_fd:
            chrom, start, end = parse_bed(line)[:3]
            raw = signal_hnd.get_as_array(chrom, start, end)
            if raw is None:
                # bx-python returns None e.g. when the chromosome is absent
                # from the bigWig; use an all-NaN row so the anchor is removed
                # by the NaN-row filter applied to the combined tracks instead
                # of crashing on `.reshape`.
                n_missing += 1
                raw = np.full(end - start, np.nan)
            Xlst.append(binned_signal(raw))
    if n_missing:
        lg.warning("%s: no signal for %d anchors (e.g. chromosome absent "
                   "from the bigWig); filled with NaN",
                   os.path.basename(signal_f), n_missing)
    lg.debug("%s: loaded data from %d anchors",
             os.path.basename(signal_f), len(Xlst))

    # (anchors X bins)
    res_x = np.array(Xlst)
    lg.info("(%s) %s (min=%.4f, max=%.4f)...",
            os.path.basename(signal_f), str(res_x.shape), res_x.min(), res_x.max())
    return res_x

# Extract every track, in parallel when --par > 1. Each element of xsig is
# the (anchors X bins) array returned by make_overlaps for one input file.
if opt.par > 1:
    pool = Pool(processes=opt.par)
    try:
        xsig = pool.map(make_overlaps, opt.input)
    finally:
        # release the worker processes instead of leaking them
        pool.close()
        pool.join()
else:
    xsig = [make_overlaps(signal_f) for signal_f in opt.input]
keep_idx = ops.get_valid_idx(xsig)
## remove the rows with NaN values
## (tracks X anchors X bins) -> rollaxis -> (anchors X tracks X bins)
xsig = np.rollaxis(np.array([x[keep_idx] for x in xsig]), 1, 0)
if opt.smooth:
    xsig = ops.inverse_hsine(xsig)
    lg.info("smoothing (min=%.4f, max=%.4f)...", xsig.min(), xsig.max())

if opt.fit:
    # NOTE(review): presumably ops.fit maps into [-1, 1]; this shifts and
    # scales the result into [0, 1] -- confirm against dimer.ops.
    xsig = (1.0 + ops.fit(xsig)) / 2.0
    lg.info("fitting (min=%.4f, max=%.4f)...", xsig.min(), xsig.max())
lg.info("%s (min=%.4f, max=%.4f)...", str(xsig.shape), xsig.min(), xsig.max())

## remove the items as indicated by keep_idx
# Panel items are the anchor names (4th BED column); the file is consumed
# inside the `with` so the handle is closed afterwards.
with open(opt.anchors) as anchors_fd:
    Xitems = [rec[3] for rec in bed.BedReader(anchors_fd)]
Xitems = np.array(Xitems)[keep_idx].tolist()

# Track labels: input file names without directory and extension.
Xmajor_axis = [os.path.splitext(os.path.basename(n))[0] for n in opt.input]

# One label per bin column, stepping by bin_size from -(span/2) to span/2,
# where span = bin_size * nr. of bins (presumably offsets from the anchor
# midpoint). `//` floors exactly like Python 2 integer `/`, and keeps the
# arguments integral under Python 3.
Xminor_axis = range(-opt.bin_size * xsig.shape[2] // 2,
                    opt.bin_size * xsig.shape[2] // 2, opt.bin_size)
# NOTE: pd.Panel was removed in pandas 0.25; this needs an older pandas.
X = pd.Panel(xsig, items=Xitems, major_axis=Xmajor_axis,
             minor_axis=Xminor_axis)

# Y: the anchors' scores (5th BED column) with --feature_score, else zeros;
# in both cases one entry per anchor kept by keep_idx.
if opt.feature_score:
    with open(opt.anchors) as anchors_fd:
        scores = [rec[4] for rec in bed.BedReader(anchors_fd)]
    Y = pd.Series(np.array(scores)[keep_idx])
else:
    Y = pd.Series(np.zeros((xsig.shape[0],)))

logging.getLogger("dimer.data").setLevel(logging.DEBUG)
data.AnchorDataset(X, Y, None).dump(opt.output)
