#!/usr/bin/env python
"""Interprets longest alphanumeric (plus - and _) strings from stdin as IDs (of given type available in ENSEMBL) and translates those to the given target ID type using ENSEMBL BioMart."""
from __future__ import print_function

__author__ = "Johannes Koester"

import sys, argparse, csv, re
from collections import defaultdict
from functools import partial
import logging

if sys.version_info >= (3, 0):
	import http.client as http

	def readbytes(data):
		"""Yield each item of data decoded from ISO-8859-1 bytes to text.

		data is iterated directly, so any iterable of bytes objects works
		(in this file it is the HTTP response object).
		"""
		for raw_line in data:
			yield raw_line.decode("iso8859-1")

else:
	import httplib as http

	def readbytes(data):
		"""Yield the newline-separated pieces of everything data.read() returns.

		The whole payload is read at once and split on "\\n"; the pieces are
		passed on unchanged.
		"""
		payload = data.read()
		for line in payload.split("\n"):
			yield line

# Disabled hard-coded ID pattern; not referenced by any code visible in this file.
#_ID_REGEXP = "[-_\w]+"
# Separator used whenever several IDs are joined into a single replacement
# string (original ID plus target, or multiple alternative targets).
_ID_DELIMITER = "\t"

def id_candidates(id):
	"""Yield id followed by its distinct upper- and lowercase variants.

	The original spelling always comes first, then the uppercase form, then
	the lowercase form. A variant is only yielded if it differs from
	everything yielded before, so an ID without cased characters (e.g. "123"
	or "") produces exactly one candidate instead of repeating itself.
	"""
	yield id
	seen = {id}
	for variant in (id.upper(), id.lower()):
		if variant not in seen:
			seen.add(variant)
			yield variant

class TableLookup:
	"""Translate IDs using an in-memory table.

	IDs found in the table are resolved immediately when enqueued; all other
	IDs are collected in a queue that subclasses may resolve in lookup().
	"""

	def __init__(self, table, keep = False):
		"""Set up the lookup.

		table -- mapping from lowercased source ID to target ID
		keep  -- if true, map() returns the original ID and the target joined
		         by _ID_DELIMITER instead of the target alone
		"""
		self._table = table
		# original ID -> resolved target
		self._idmap = dict()
		# case variants still waiting to be resolved by lookup()
		self._queue = set()
		# queued case variant -> ID as it was passed to enqueue_lookup()
		self._orig_id = dict()
		self._keep = keep

	def enqueue_lookup(self, id):
		""" Add a possible id for lookup.

		Empty ids are ignored. An id whose lowercase form is in the table is
		resolved right away; otherwise each of its case variants that has no
		mapping yet is queued.
		"""
		if not id:
			return
		key = id.lower()
		if key in self._table:
			self._idmap[id] = self._table[key]
			return
		for candidate in id_candidates(id):
			if candidate not in self._idmap:
				self._queue.add(candidate)
				self._orig_id[candidate] = id

	def map(self, id):
		""" Map id to the looked up target.

		The case variants of id are tried in order and the first one with a
		known mapping wins. With keep enabled, the id exactly as passed in
		(not the case variant that happened to match) is prepended to the
		target. If nothing matches, id is returned unchanged.
		"""
		for candidate in id_candidates(id):
			try:
				target = self._idmap[candidate]
			except KeyError:
				continue
			if self._keep:
				return _ID_DELIMITER.join((id, target))
			return target

		return id

	def lookup(self):
		""" Resolve the queued ids; the plain table lookup has nothing to do. """
		pass

	def has_full_queue(self):
		""" Tell whether the queue should be resolved now; always False here. """
		return False

	def has_empty_queue(self):
		""" Tell whether no ids are waiting in the queue. """
		return not self._queue

class BioMartLookup(TableLookup):
	"""Translate IDs by querying a BioMart service over HTTP.

	IDs present in the overwrite table are resolved from that table; all
	other queued IDs are sent to BioMart once lookup() is called. The
	connection is shared by the whole class and must be created with
	set_server() before any request is made.
	"""

	_query = """<?xml version="1.0" encoding="UTF-8"?>
		<!DOCTYPE Query>
		<Query  virtualSchemaName = "default" formatter = "TSV" header = "0" uniqueRows = "0" count = "" datasetConfigVersion = "0.6" >

		<Dataset name = "{dataset}" interface = "default" >
			<Filter name = "{source}" value = "{ids}"/>
			<Attribute name = "{source}" />
			<Attribute name = "{target}" />
		</Dataset>
		</Query>"""


	@classmethod
	def set_server(cls, server):
		""" Create the class-wide HTTP connection object for the given host. """
		cls._biomart = http.HTTPConnection(server)


	def __init__(self, dataset, sources, target, overwrite = None, all = False, keep = False):
		"""Set up the lookup.

		dataset   -- name of the BioMart dataset to query
		sources   -- iterable of source ID types; each is queried in turn
		target    -- target ID type
		overwrite -- optional mapping from lowercased ID to target that takes
		             precedence over BioMart (defaults to an empty mapping)
		all       -- if true, replace an ID by all found targets instead of
		             only the smallest one
		keep      -- if true, keep the original ID in front of the target
		"""
		# A fresh dict per instance avoids sharing one default object.
		TableLookup.__init__(self, dict() if overwrite is None else overwrite, keep = keep)
		self._dataset = dataset
		self._sources = sources
		self._target = target
		self._all = all

	def lookup(self):
		""" Perform the actual lookup in BioMart for the queued ids.

		One POST request is issued per source type. Each returned row is
		expected to hold the source ID and the target ID; rows that are too
		short, have an empty target, or name an ID that was not queued are
		skipped. Found targets are stored under the ID as originally enqueued.
		"""
		if not self._queue:
			# Nothing was queued, so no returned row could match anyway.
			return

		ids = ",".join(self._queue)
		targetids = defaultdict(set)
		for source in self._sources:
			# NOTE(review): ids are inserted verbatim into the XML attribute and
			# the request body; ids containing characters such as '"', '&' or
			# '<' will yield a malformed query. Confirm the encoding the
			# server expects before adding escaping.
			body = "query=" + self._query.format(dataset=self._dataset, source=source, target=self._target, ids=ids) + "\n"
			logging.info("querying biomart")
			self._biomart.request("POST", "/biomart/martservice?", body=body)
			response = self._biomart.getresponse()
			for row in csv.reader(readbytes(response), delimiter="\t"):
				# Check the length first so that one-column rows are skipped
				# instead of raising IndexError.
				if len(row) > 1 and row[1] and row[0] in self._orig_id:
					targetids[self._orig_id[row[0]]].add(row[1])

		if self._all:
			# Sorted so that the output does not depend on set ordering.
			select = lambda sourceid, found: _ID_DELIMITER.join(sorted(found))
		else:
			select = self.select_smallest_targetid
		for sourceid, found in targetids.items():
			self._idmap[sourceid] = select(sourceid, found)

	@staticmethod
	def select_smallest_targetid(sourceid, targetids):
		""" Return the shortest target ID, ties broken alphabetically.

		A warning listing all alternatives is written to stderr when more
		than one target ID is available.
		"""
		targetid = min(targetids, key=lambda t: (len(t), t))
		if len(targetids) > 1:
			print("Warning: Selecting {} from ambiguous target IDs for {}: {}".format(targetid, sourceid, ";".join(sorted(targetids))), file=sys.stderr)
		return targetid

	@classmethod
	def _request_rows(cls, path):
		""" Send a GET request for path and return a reader over the TSV rows of the response. """
		cls._biomart.request("GET", path)
		return csv.reader(readbytes(cls._biomart.getresponse()), delimiter="\t")

	@classmethod
	def get_datasets(cls):
		""" Get the list of available datasets """
		for row in cls._request_rows("/biomart/martservice?type=datasets&mart=ensembl"):
			# Columns 1 and 2 are printed, so at least three are required.
			if len(row) > 2:
				yield "{}\t({})".format(*row[1:3])


	@classmethod
	def get_sources(cls, dataset):
		""" Get the list of available sources (filters of type id_list) """
		for row in cls._request_rows("/biomart/martservice?type=filters&dataset={}".format(dataset)):
			if len(row) > 5 and row[5] == "id_list":
				yield "{}\t({})".format(*row[:2])

	@classmethod
	def get_targets(cls, dataset):
		""" Get the list of available targets (attributes on the feature_page) """
		for row in cls._request_rows("/biomart/martservice?type=attributes&dataset={}".format(dataset)):
			if len(row) > 3 and row[3] == "feature_page":
				yield "{}\t({})".format(*row[:2])


def _read_table(path):
	"""Load a tab-delimited ID map from the file at path.

	Returns a dict from the lowercased first column to the second column, or
	to the first column itself when a row has only one column. Blank rows and
	rows whose first column starts with "#" are skipped.
	"""
	table = dict()
	with open(path) as handle:
		for row in csv.reader(handle, delimiter="\t"):
			# csv yields an empty list for blank lines; indexing it would fail.
			if not row or row[0].startswith("#"):
				continue
			table[row[0].lower()] = row[1] if len(row) > 1 else row[0]
	return table


def main():
	"""Command line entry point.

	Depending on the arguments this either lists the datasets, sources or
	targets offered by the BioMart server, or copies stdin to stdout with
	every string matching the ID pattern replaced by its translation.
	"""
	parser = argparse.ArgumentParser(description=__doc__)
	parser.add_argument("--server", default="www.ensembl.org", help="Local ensembl server (e.g. useast.ensembl.org)")
	parser.add_argument("--sources", "-s", nargs="+", help="Types of input IDs (e.g. UniProtKB).")
	parser.add_argument("--target", "-t", help="Type of output IDs (e.g. KEGG or NAME).")
	parser.add_argument("--dataset", "-d", default="hsapiens_gene_ensembl", help="An ENSEMBL dataset (default: hsapiens_gene_ensembl).")
	parser.add_argument("--listdatasets", "--ld", action="store_true", help="List the possible ENSEMBL datasets.")
	parser.add_argument("--listsources", "--ls", action="store_true", help="List the possible ID sources.")
	parser.add_argument("--listtargets", "--lt", action="store_true", help="List the available targets for a given source")
	parser.add_argument("--table", metavar="FILE", help="A file that contains a tab delimited map of IDs (will overwrite the decision of the lookup if used in combination with source and target).")
	parser.add_argument("--minlength", metavar="N", default=1, type=int, help="Minimum length for an input ID to be considered for translation.")
	parser.add_argument("--ignore-lowercase", "-i", action="store_true", help="Ignore lowercase words since they are likely no ids")
	parser.add_argument("--replace-by-all", "-a", action="store_true", help="Replace an ID by a tab separated list of all found alternative target IDs.")
	parser.add_argument("--keep-id", "-k", action="store_true", help="Keep the original id prepended and separated by a tab.")
	# Raw string: same pattern as before, without invalid escape sequences.
	parser.add_argument("--id-regexp", "-r", default=r"[^\(\)\[\]\{\}\s,;]+", help="Python regular expression for ids to replace (default: runs of characters other than whitespace, parentheses, brackets, braces, commas and semicolons).")
	args = parser.parse_args()

	logging.basicConfig(format="%(message)s", level=logging.INFO, stream=sys.stderr)

	BioMartLookup.set_server(args.server)

	if args.listdatasets:
		print("Available datasets:")
		for dataset in BioMartLookup.get_datasets():
			print(dataset)
	elif args.listsources:
		print("Available sources:")
		for source in BioMartLookup.get_sources(args.dataset):
			print(source)
	elif args.listtargets:
		print("Available targets:")
		for target in BioMartLookup.get_targets(args.dataset):
			print(target)
	else:
		table = _read_table(args.table) if args.table else dict()

		# Select the lookup method based on the presence of sources and target
		if args.sources and args.target:
			db = BioMartLookup(args.dataset, args.sources, args.target, overwrite = table, all = args.replace_by_all, keep = args.keep_id)
		else:
			db = TableLookup(table, keep = args.keep_id)

		id_regexp = re.compile(args.id_regexp)
		linebuffer = []

		def replace():
			""" Print all buffered lines with their IDs translated and empty the buffer. """
			for line in linebuffer:
				print(id_regexp.sub(lambda match: db.map(match.group(0)), line), end="")
			del linebuffer[:]

		for line in sys.stdin:
			# finditer + group(0) always gives the whole match, which is what
			# replace() substitutes, even if the pattern has capturing groups.
			for match in id_regexp.finditer(line):
				candidate = match.group(0)
				if len(candidate) >= args.minlength and (not args.ignore_lowercase or not candidate.islower()):
					db.enqueue_lookup(candidate)
			linebuffer.append(line)

			if db.has_full_queue():
				db.lookup()
				replace()

		if not db.has_empty_queue():
			db.lookup()
		# Always flush: lines must be written even when every ID was resolved
		# from the table or no ID was queued at all.
		replace()


# Run the command line interface only when executed as a script, not on import.
if __name__ == '__main__':
	main()
