#!/usr/bin/env python3
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this Software except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

import argparse
import asyncio
import binascii
import concurrent.futures
from datetime import datetime
from elasticsearch import Elasticsearch
import ipaddress
import json
import re
import requests
from requests.auth import HTTPDigestAuth
import sys

# Command line query for Moloch, including full packet search

parser = argparse.ArgumentParser(
    description="Query and full packet search for Moloch",
)

# What to search for and over which time window.
parser.add_argument(
    "-q", "--query",
    help="Moloch query",
)
parser.add_argument(
    "--wreck-the-cluster",
    action="store_true",
    help="Don't do this unless necessary",
)
parser.add_argument(
    "-s", "--start",
    help="Start Date",
)
parser.add_argument(
    "-e", "--end",
    help="End Date",
)

# Output shaping and packet filtering.
parser.add_argument(
    "-f", "--fields",
    help="Comma delimited list of fields to return",
)
parser.add_argument(
    "--headers",
    action="store_true",
    help="Add headers to the tsv",
)
parser.add_argument(
    "--regex",
    help="Regex to use on packet",
)
parser.add_argument(
    "--hexregex",
    help="Regex to use on the hexed packet",
)
parser.add_argument(
    "--limit",
    type=int,
    help="Only produce this manyish results (rounded up to pagesize)",
)
parser.add_argument(
    "-l", "--listfields",
    action="store_true",
    help="List allowed Moloch Fields for --fields",
)

# Tuning knobs.
parser.add_argument(
    "-c", "--concurrent",
    type=int,
    default=10,
    help="Number of concurrent packet's to search",
)
parser.add_argument(
    "-p", "--pagesize",
    type=int,
    default=1000,
    help="Results to grab from elasticsearch at a time",
)
parser.add_argument(
    "-t", "--timeout",
    default="2m",
    help="Elasticsearch scroll timeout",
)

# Endpoints and credentials; argparse itself rejects a run without these.
parser.add_argument(
    "--esurl",
    required=True,
    help="Elasticsearch URL (proto://host:port)",
)
parser.add_argument(
    "--molurl",
    required=True,
    help="Moloch node (proto://host:port)",
)
parser.add_argument(
    "--apiuser",
    required=True,
    help="Moloch API username",
)
parser.add_argument(
    "--apipass",
    required=True,
    help="Moloch API password",
)

# Parsed once at import time; every function below reads this module global.
args = parser.parse_args()

# HTTP(S) session reused by get_query() for its request to the Moloch API.
sess = requests.Session()  # Http(s) connection to Moloch for finding sessions
# Assigned by worker_init(); fetch_search_packet() sends its PCAP requests
# through whatever session is stored here.
s = None  # Placeholder for worker http(s) sessions

# Elasticsearch client for the --esurl endpoint. The sniff_* keyword
# arguments are handed straight to the client constructor.
es = Elasticsearch([args.esurl],
                   sniff_on_start=True,
                   sniff_on_connection_fail=True,
                   sniffer_timeout=60)

# Reference point for the elapsed-time figures reported by print_progress()
# and main().
# NOTE(review): captured after the client above is constructed, so any time
# spent there is not counted -- confirm that is intended.
start_time = datetime.now()


#
# Log an entry to stderr
#
def log_entry(text):
    """Write *text* to stderr, prefixed by the current timestamp.

    The timestamp and the message are separated by one space and the
    message is terminated with a newline. *text* must be a ``str``.
    """
    stream = sys.stderr
    stamp = str(datetime.now())
    # Two writes, timestamp first, so the prefix is emitted even if
    # concatenating the message fails.
    stream.write(stamp + " ")
    stream.write(text + "\n")


#
# Get field mapping from elasticsearch. (i.e. moloch field -> elasticsearch field)
#
def get_field_map():
    """Return a dict mapping Moloch field names to Elasticsearch field names.

    Reads up to 1000 documents from the ``fields`` index. Each document's
    ``_id`` becomes a key and its ``dbField2`` entry the value; documents
    whose ``regex`` entry is truthy are left out.
    """
    log_entry("Getting fields from elasticsearch")
    response = es.search(body={"query": {}},
                         index="fields",
                         request_timeout=120,
                         size=1000)
    mapping = {
        doc["_id"]: doc["_source"]["dbField2"]
        for doc in response["hits"]["hits"]
        if not doc["_source"].get("regex", False)
    }
    log_entry("Done.")
    return mapping


#
# Get ES query from moloch
#

def get_query(req):
    """Have Moloch translate *req* into an Elasticsearch query.

    *req* is sent as query-string parameters to ``<molurl>/api/buildquery``
    with HTTP digest authentication. The JSON response body is decoded and
    returned. An HTTP error status raises via ``raise_for_status()``.
    """
    endpoint = args.molurl + "/api/buildquery"
    credentials = HTTPDigestAuth(args.apiuser, args.apipass)
    response = sess.get(endpoint, params=req, auth=credentials)
    response.raise_for_status()
    return json.loads(response.text)


#
# Print a progress indicator to stderr
#
def print_progress(cur, tot):
    """Log percent complete, elapsed time and estimated time remaining.

    cur: number of sessions handled so far.
    tot: total number of sessions expected; 0 is treated as 1 so the
        percentage can always be computed.

    The remaining-time estimate scales the elapsed time by ``tot / cur``.
    When nothing has been handled yet (``cur == 0``) no estimate is
    possible, so "unknown" is reported instead of dividing by zero.
    """
    if tot == 0:
        tot = 1
    pct = int(cur / tot * 100)
    elap = datetime.now() - start_time
    if cur > 0:
        remain = elap * (tot / cur) - elap
    else:
        # A query that matches no sessions reaches here with cur == 0.
        remain = "unknown"
    log_entry("{0}% : {1} sessions of {2}, Elapsed: {3}, Remaining: {4}".format(pct, cur, tot, elap, remain))


#
# Print fields from sessions as a TSV to stdout
#
def print_tsv(hits, fields):
    """Print one tab-separated line per hit to stdout.

    hits: iterable of Elasticsearch hit dicts. ``None`` entries are
        skipped (fetch_search_packet returns ``None`` for sessions that
        did not match the packet regex).
    fields: ordered list of ``_source`` keys, one per output column.
        The name ``"id"`` is special and selects the hit's ``_id``.

    List values are joined with commas (each element stringified first,
    so numeric lists work); other non-string values go through ``str()``.
    A field missing from a hit yields an empty column.
    """
    for hit in hits:
        if hit is None:
            continue
        source = hit.get("_source") or {}
        row = []
        for field in fields:
            # "id" is a special snowflake: it lives on the hit, not in _source.
            if field == "id":
                val = hit.get("_id", "")
            else:
                val = source.get(field, "")

            if isinstance(val, list):
                val = ",".join(map(str, val))
            elif not isinstance(val, str):
                val = str(val)

            # Render integer-encoded IPv4 addresses as dotted quads. Not
            # exhaustive. An absent value stays an empty column rather than
            # failing in int("").
            if field in ("a1", "a2") and val != "":
                val = str(ipaddress.IPv4Address(int(val)))

            row.append(val)
        print("\t".join(row))


# Set up the HTTP session shared by the packet-fetch worker threads
def worker_init():
    """Point the module-level ``s`` at a newly created ``requests.Session``.

    packet_search() calls this before dispatching work, and
    fetch_search_packet() sends its requests through ``s``.
    """
    global s
    fresh_session = requests.Session()
    s = fresh_session


#
# Fetch packet (as PCAP) from Moloch. Search it if necessary.
#
def fetch_search_packet(x):
    """Download one session's PCAP from Moloch and apply the packet filter.

    x: an Elasticsearch hit dict; ``x["_id"]`` is sent as the ``ids``
        parameter of ``<molurl>/sessions.pcap``.

    The response body, hex-encoded, is stored in ``x["_source"]["packet"]``
    (the hit is modified in place). Returns *x* when no packet regex was
    requested or when the requested one matches, otherwise ``None``.
    ``--regex`` is matched against the raw bytes, ``--hexregex`` against the
    hex encoding.

    NOTE(review): the HTTP status of the response is not checked here, so
    whatever body comes back is treated as packet data -- confirm intended.
    """
    params = {"ids": x["_id"]}
    endpoint = "{0}/sessions.pcap".format(args.molurl)
    credentials = HTTPDigestAuth(args.apiuser, args.apipass)
    response = s.get(endpoint, auth=credentials, params=params)

    raw = response.content
    hexed = binascii.hexlify(raw)
    x["_source"]["packet"] = hexed.decode("utf-8")

    # Pick the pattern and the buffer it applies to; --regex wins if both
    # happen to be set.
    if args.regex:
        pattern, haystack = args.regex, raw
    elif args.hexregex:
        pattern, haystack = args.hexregex, hexed
    else:
        return x

    if re.search(pattern.encode("utf-8"), haystack):
        return x
    return None


#
# Dispatch fetching of packets to some workers
#
async def packet_search(reg, hreg, data):
    """Run fetch_search_packet over every hit in *data* on a thread pool.

    reg, hreg: accepted for the caller's benefit but not read here; the
        worker takes the patterns from ``args`` directly.
    data: iterable of Elasticsearch hit dicts.

    Returns a list with one entry per input hit, in input order: whatever
    fetch_search_packet returned for it (the hit, or ``None``).
    """
    worker_init()
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.concurrent) as pool:
        event_loop = asyncio.get_event_loop()
        pending = []
        for session_hit in data:
            pending.append(event_loop.run_in_executor(pool, fetch_search_packet, session_hit))
        return list(await asyncio.gather(*pending))


#
# Get sessions from ES and process
#
def _page_needs_packets():
    """Return True when each session's PCAP has to be fetched from Moloch."""
    return bool(args.regex
                or args.hexregex
                or "packet" in process_sessions.raw_fields)


def _emit_page(hits):
    """Packet-search one page of hits if required, then print it as TSV."""
    if _page_needs_packets():
        loop = asyncio.get_event_loop()
        hits = loop.run_until_complete(packet_search(reg=args.regex,
                                                     hreg=args.hexregex,
                                                     data=hits))
    print_tsv(hits, process_sessions.db_fields)


def _release_scroll(scroll_id):
    """Best-effort release of the server-side scroll context."""
    if not scroll_id:
        return
    try:
        es.clear_scroll(scroll_id=scroll_id)
    except Exception as err:
        # Cleanup only: never let a failed release replace the real outcome
        # (normal return, SystemExit, or an earlier exception).
        log_entry("Could not clear scroll context: {0}".format(err))


def process_sessions(data):
    """Run the Elasticsearch query and stream the sessions to stdout as TSV.

    data: dict describing the query; ``data["esquery"]["query"]`` is used
        as the query body and ``data["indices"]`` as the index list.

    Results are read with the scroll API, ``--pagesize`` hits at a time.
    Every page is optionally packet-searched, printed, and followed by a
    progress line on stderr. Reading stops when a page comes back empty or
    once at least ``--limit`` sessions have been read (the limit counts
    sessions read, including ones the packet regex later drops).

    If no output fields were requested, only the total hit count is
    printed and the process exits with status 0.

    With ``--headers`` a header row of the requested field names is printed
    once before the first data row.
    """
    fields = process_sessions.db_fields

    page = es.search(body={"query": data["esquery"]["query"]},
                     index=data["indices"],
                     doc_type="session",
                     request_timeout=120,
                     size=args.pagesize,
                     scroll=args.timeout,
                     _source=fields)
    scroll_id = page.get("_scroll_id")

    try:
        total = page["hits"]["total"]
        if isinstance(total, dict):
            # Newer Elasticsearch reports {"value": n, "relation": ...}.
            total = total["value"]

        # If no fields requested, just print total and exit
        if not fields:
            print(total)
            sys.exit(0)

        if args.headers and not process_sessions.headers_printed:
            print("\t".join(process_sessions.raw_fields))
            process_sessions.headers_printed = True

        read_so_far = 0
        while True:
            hits = page["hits"]["hits"]
            if not hits:
                break
            read_so_far += len(hits)

            _emit_page(hits)
            print_progress(read_so_far, total)

            if args.limit is not None and read_so_far >= args.limit:
                break

            # Keep the scroll alive for the user-chosen --timeout on every
            # page, not only on the initial search.
            page = es.scroll(scroll_id=scroll_id,
                             scroll=args.timeout,
                             request_timeout=120)
            scroll_id = page.get("_scroll_id", scroll_id)
    finally:
        # Free the scroll context instead of leaving it to time out on the
        # cluster, including on the early-exit paths above.
        _release_scroll(scroll_id)


# "Static" variables for process_sessions
process_sessions.cur_entry = 0
# Set once the --headers row has been written.
process_sessions.headers_printed = False
# Elasticsearch field names to request and print, in column order.
process_sessions.db_fields = []
# The field names exactly as given to --fields, in the same order.
process_sessions.raw_fields = []


def main():
    """Validate the command line, build the query and export the sessions.

    Exit statuses: 0 after ``--listfields`` (or a count-only run),
    2 for invalid command-line usage, 1 when a requested field is unknown
    or Moloch reports a query-building error (``bsqErr``).
    """
    # Listing the available fields needs no query or time range, so handle
    # it before the validation below.
    if args.listfields:
        for key in get_field_map():
            print(key)
        sys.exit(0)

    if args.regex and args.hexregex:
        sys.stderr.write("Only one of --regex or --hexregex allowed.\n")
        sys.exit(2)
    if not args.end:
        sys.stderr.write("No end time specified.\n")
        sys.exit(2)
    if not args.start:
        sys.stderr.write("No start time.\n")
        sys.exit(2)
    if not args.wreck_the_cluster and not args.query:
        sys.stderr.write("No Moloch query. Use --wreck-the-cluster if this is REALLY what you want.\n")
        sys.exit(2)

    # Reject a malformed pattern now rather than after the cluster has been
    # queried. Patterns are applied to bytes, so compile them as bytes.
    for option, pattern in (("--regex", args.regex), ("--hexregex", args.hexregex)):
        if not pattern:
            continue
        try:
            re.compile(pattern.encode("utf-8"))
        except re.error as err:
            sys.stderr.write("Invalid {0} pattern: {1}\n".format(option, err))
            sys.exit(2)

    moloch_sess_query = {
        "facets": 0,
        "length": args.pagesize,
        "strictly": 0,
        "startTime": args.start,
        "stopTime": args.end,
    }
    if args.query:
        moloch_sess_query["expression"] = args.query

    if args.fields:
        field_map = get_field_map()
        # Convert moloch fields to ES db fields
        process_sessions.raw_fields = [name.strip() for name in args.fields.split(",")]
        for field in process_sessions.raw_fields:
            if field == "packet":
                # Not an Elasticsearch field: filled in from the fetched PCAP.
                process_sessions.db_fields.append("packet")
            elif field in field_map:
                process_sessions.db_fields.append(field_map[field])
            else:
                sys.stderr.write("Field '{0}' not found.\n".format(field))
                sys.exit(1)
        moloch_sess_query["fields"] = ",".join(process_sessions.db_fields)

    query = get_query(moloch_sess_query)
    if query.get("bsqErr", False):
        # stdout carries the TSV data; report errors on stderr like the rest.
        sys.stderr.write("{0}\n".format(query["bsqErr"]))
        sys.exit(1)

    process_sessions(query)

    log_entry("Finished. Query took " + (str(datetime.now() - start_time)) + "\n")


if __name__ == "__main__":
    main()
