#!/usr/bin/env python


"""split BED regions into equal-sized smaller BED regions"""

import sys
import argparse
import logging
from itertools import imap


# The script logs through the root logger; INFO and above are shown.
lg = logging.getLogger()
logging.basicConfig(level=logging.INFO)
# Reminder for a not-yet-implemented feature, surfaced on every run.
lg.warning("TODO: add pick bin-i/nbins option")


# Command-line interface; the module docstring doubles as the description.
parser = argparse.ArgumentParser(
    description=__doc__,
    epilog="Taylor Lab (odenas@emory.edu)",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# positional arguments: input file, slice length, distance between slice starts
parser.add_argument("input", help="Regions (>= BED3)", type=str)
parser.add_argument("osize", help="output size", type=int)
parser.add_argument("step", help="step / stride", type=int)

# optional arguments
parser.add_argument("--range", type=int, nargs=2, default=[],
                    help="Limit output to elements in this range.")
parser.add_argument("--border", default="crop", choices=("crop", "extend"),
                    help="what to do if last region is smaller than osize")

opt = parser.parse_args()


# Validate numeric arguments before doing any work:
#  - step == 0 makes xrange() raise, and step < 0 yields no slices at all;
#  - osize <= 0 produces empty (s, s) or inverted (s, s - k) intervals;
#  - an inverted --range can never satisfy the filter, so output is silently
#    empty.
# parser.error() prints the usage line and exits with status 2.
if opt.osize <= 0:
    parser.error("osize must be a positive integer")
if opt.step <= 0:
    parser.error("step must be a positive integer")
if opt.range and opt.range[0] > opt.range[1]:
    parser.error("--range START must not be greater than END")
# A stride larger than the slice size leaves gaps between consecutive slices.
if opt.step > opt.osize:
    lg.warning("stride is larger than slice size. might be loosing info")

def get_slices(st, en, size=opt.osize, step=opt.step,
               extend=opt.border == "extend"):
    """iterate over slices between `st` and `en`

    Slices start at `st`, `st + step`, `st + 2 * step`, ... and every slice is
    exactly `size` long.

    :param int st: starting point
    :param int en: ending point (excluded)
    :param int size: size of slices
    :param int step: distance between starting points of 2 consecutive slices
    :param bool extend: if True, every slice that starts before `en` is
        emitted, so the last ones may run past `en`; if False, only slices
        lying entirely within [`st`, `en`) are emitted and trailing ones that
        would cross `en` are dropped
    :rtype: iterator of pairs (start, end)"""

    # xrange's stop is exclusive. When dropping, the last valid start is
    # en - size (its slice ends exactly at en), so the stop must be
    # en - size + 1. Using en - size would lose that final, perfectly fitting
    # slice, and a region exactly `size` long would yield nothing at all.
    stop = en if extend else en - size + 1
    return ((s, s + size) for s in xrange(st, stop, step))


def parse_bedline(line):
    """Split one BED line into its whitespace-separated fields.

    Leading/trailing whitespace (including the newline) is discarded, so a
    blank line yields an empty list.
    """
    # str.split() with no separator already ignores surrounding whitespace.
    return line.split()

# This module parses sys.argv and prints to stdout at top level, so it stops
# here when imported instead of being run as a script.
if __name__ != "__main__":
    lg.error("this is a script do not import")
    raise SystemExit(1)

def _is_region(fields):
    """Tell whether the split fields of a BED line describe a region.

    Blank lines, '#' comments and UCSC ``track``/``browser`` header lines do
    not, and would otherwise crash the int() conversions below.
    """
    if not fields:
        return False
    return not (fields[0].startswith("#") or fields[0] in ("track", "browser"))


with open(opt.input) as fd:
    # Only real region lines are enumerated, so region_idx counts regions.
    regions = (bed_comp for bed_comp in imap(parse_bedline, fd)
               if _is_region(bed_comp))
    for region_idx, bed_comp in enumerate(regions):
        start, end = int(bed_comp[1]), int(bed_comp[2])
        for i, (s, e) in enumerate(get_slices(start, end)):
            # `i` counts every slice of the region, including those filtered
            # out below, so a slice's name does not depend on --range.
            if opt.range and not (opt.range[0] <= s <= e <= opt.range[1]):
                continue
            # Name slices after the region's BED name when it has one,
            # otherwise after the region's position in the input.
            if len(bed_comp) > 3:
                name = "%s_%d" % (bed_comp[3], i)
            else:
                name = "%d_%d" % (region_idx, i)
            # Any columns after the name are copied through unchanged.
            print("\t".join([bed_comp[0], str(s), str(e), name] + bed_comp[4:]))
