"""MultiQC submodule to parse output from Picard HsMetrics"""

import logging
from collections import defaultdict

import re
from typing import Any, Dict, List, Optional, Set, cast

from multiqc import config
from multiqc.base_module import BaseMultiqcModule
from multiqc.modules.picard import util
from multiqc.plots import linegraph, table
from multiqc.plots.table_object import ColumnDict, TableConfig
from multiqc.types import ColumnKey

# Initialise the logger
log = logging.getLogger(__name__)

# Column descriptions for HsMetrics fields, keyed by the metric name in the Picard output.
# Keys containing "{coverage}" are templates: the "{coverage}" placeholder in both the key
# and the description is substituted with the actual coverage threshold (e.g. 30 for
# PCT_TARGET_BASES_30X) when table headers are generated.
FIELD_DESCRIPTIONS = {
    "AT_DROPOUT": "A measure of how undercovered <= 50% GC regions are relative to the mean.",
    "BAIT_DESIGN_EFFICIENCY": "Target territory / bait territory. 1 == perfectly efficient, 0.5 = half of baited bases are not target.",
    "BAIT_SET": "The name of the bait set used in the hybrid selection.",
    "BAIT_TERRITORY": "The number of bases which have one or more baits on top of them.",
    "FOLD_80_BASE_PENALTY": 'The fold over-coverage necessary to raise 80% of bases in "non-zero-cvg" targets to the mean coverage level in those targets.',
    "FOLD_ENRICHMENT": "The fold by which the baited region has been amplified above genomic background.",
    "GC_DROPOUT": "A measure of how undercovered >= 50% GC regions are relative to the mean.",
    "GENOME_SIZE": "The number of bases in the reference genome used for alignment.",
    "HET_SNP_Q": "The Phred Scaled Q Score of the theoretical HET SNP sensitivity.",
    "HET_SNP_SENSITIVITY": "The theoretical HET SNP sensitivity.",
    "HS_LIBRARY_SIZE": "The estimated number of unique molecules in the selected part of the library.",
    "HS_PENALTY_{coverage}X": "The 'hybrid selection penalty' incurred to get 80% of target bases to {coverage}X. This metric should be interpreted as: if I have a design with 10 megabases of target, and want to get {coverage}X coverage I need to sequence until PF_ALIGNED_BASES = 10^7 * {coverage} * HS_PENALTY_{coverage}X.",
    "MAX_TARGET_COVERAGE": "The maximum coverage of reads that mapped to target regions of an experiment.",
    "MEAN_BAIT_COVERAGE": "The mean coverage of all baits in the experiment.",
    "MEAN_TARGET_COVERAGE": "The mean coverage of targets.",
    "MEDIAN_TARGET_COVERAGE": "The median coverage of targets.",
    "NEAR_BAIT_BASES": "The number of PF aligned bases that mapped to within a fixed interval of a baited region, but not on a baited region.",
    "OFF_BAIT_BASES": "The number of PF aligned bases that mapped to neither on or near a bait.",
    # NOTE(review): not a Picard field name; probably a typo of FOLD_80_BASE_PENALTY. Removing it
    # would make user configs that list this column start logging "not found" errors — confirm first.
    "OLD_80_BASE_PENALTY": 'The fold over-coverage necessary to raise 80% of bases in "non-zero-cvg" targets to the mean coverage level in those targets.',
    "ON_BAIT_BASES": "The number of PF aligned bases that mapped to a baited region of the genome.",
    "ON_BAIT_VS_SELECTED": "The percentage of on+near bait bases that are on as opposed to near.",
    "ON_TARGET_BASES": "The number of PF aligned bases that mapped to a targeted region of the genome.",
    "PCT_EXC_BASEQ": "The fraction of aligned bases that were filtered out because they were of low base quality.",
    "PCT_EXC_DUPE": "The fraction of aligned bases that were filtered out because they were in reads marked as duplicates.",
    "PCT_EXC_MAPQ": "The fraction of aligned bases that were filtered out because they were in reads with low mapping quality.",
    "PCT_EXC_OFF_TARGET": "The fraction of aligned bases that were filtered out because they did not align over a target base.",
    "PCT_EXC_OVERLAP": "The fraction of aligned bases that were filtered out because they were the second observation from an insert with overlapping reads.",
    "PCT_OFF_BAIT": "The percentage of aligned PF bases that mapped neither on or near a bait.",
    "PCT_PF_READS": "PF reads / total reads. The percent of reads passing filter.",
    "PCT_PF_UQ_READS_ALIGNED": "PF Reads Aligned / PF Reads.",
    "PCT_PF_UQ_READS": "PF Unique Reads / Total Reads.",
    "PCT_SELECTED_BASES": "On+Near Bait Bases / PF Bases Aligned.",
    "PCT_TARGET_BASES_{coverage}X": "The fraction of all target bases achieving {coverage}X or greater coverage.",
    "PCT_USABLE_BASES_ON_BAIT": "The number of aligned, de-duped, on-bait bases out of the PF bases available.",
    "PCT_USABLE_BASES_ON_TARGET": "The number of aligned, de-duped, on-target bases out of the PF bases available.",
    "PF_BASES": "The number of bases in the PF reads.",
    "PF_BASES_ALIGNED": "The number of PF unique bases that are aligned with mapping score > 0 to the reference genome.",
    "PF_READS": "The number of reads that pass the vendor's filter.",
    "PF_UNIQUE_READS": "The number of PF reads that are not marked as duplicates.",
    "PF_UQ_BASES_ALIGNED": "The number of bases in the PF aligned reads that are mapped to a reference base. Accounts for clipping and gaps.",
    "PF_UQ_READS_ALIGNED": "The number of PF unique reads that are aligned with mapping score > 0 to the reference genome.",
    "TARGET_TERRITORY": "The unique number of target bases in the experiment where target is usually exons etc.",
    "TOTAL_READS": "The total number of reads in the SAM or BAM file examined.",
    "ZERO_CVG_TARGETS_PCT": "The fraction of targets that did not reach coverage=1 over any base.",
}


def parse_reports(module: BaseMultiqcModule) -> Set[str]:
    """
    Find Picard HsMetrics reports, parse them and add their sections to the report.

    Both Picard ``CollectHsMetrics`` and Sentieon ``HsMetricAlgo`` output are recognised.
    A single file may contain several samples, and each sample may report several
    bait sets. When a sample has more than one bait set, every bait set becomes its
    own row named ``"<sample>: <bait>"``; otherwise the row is named after the sample.

    Adds General Stats columns, the HsMetrics table, the target-coverage plot and,
    when any penalty is positive, the hybrid-selection penalty plot.

    :param module: the parent module used to find files and add report sections.
    :return: the names of the reported rows (empty set if nothing was parsed).
    """

    # sample name -> bait set -> metric name -> value (float where parsable, else the raw string)
    data_by_bait_by_sample: Dict[str, Dict[str, Dict[str, Any]]] = dict()
    # sample name -> the log file its (last) table was read from; used for data sources/messages
    source_by_sample: Dict[str, Any] = dict()

    # Go through logs and find Metrics
    for f in module.find_log_files("picard/hsmetrics", filehandles=True):
        s_name: Optional[str] = f["s_name"]
        keys: Optional[List[str]] = None
        commadecimal: Optional[bool] = None

        for line in f["f"]:
            maybe_s_name: Optional[str] = util.extract_sample_name(
                module,
                line,
                f,
                picard_tool="CollectHsMetrics",
                sentieon_algo="HsMetricAlgo",
            )
            if maybe_s_name:
                s_name = maybe_s_name
                keys = None

            if util.is_line_right_before_table(line, picard_class="HsMetrics", sentieon_algo="HsMetricAlgo"):
                # The column header row immediately follows the marker line
                keys = cast(List[str], f["f"].readline().strip("\n").split("\t"))
                if s_name in data_by_bait_by_sample:
                    log.debug(f"Duplicate sample name found in {f['fn']}! Overwriting: {s_name}")
                data_by_bait_by_sample[s_name] = dict()
                source_by_sample[s_name] = f

            elif keys:
                vals = line.strip("\n").split("\t")
                if len(vals) != len(keys):
                    # Column count mismatch marks the end of the table
                    keys = None
                    continue

                bait = vals[0] if keys[0] == "BAIT_SET" else "NA"

                # Decide once per file (from its first data row) whether commas are decimal separators
                if commadecimal is None:
                    commadecimal = any(
                        "," in val
                        for key, val in zip(keys, vals)
                        if "PCT" in key or "BAIT" in key or "MEAN" in key
                    )

                metrics: Dict[str, Any] = dict()
                for key, val in zip(keys, vals):
                    if commadecimal:
                        val = val.replace(".", "").replace(",", ".")
                    try:
                        metrics[key] = float(val)
                    except ValueError:
                        metrics[key] = val
                data_by_bait_by_sample[s_name][bait] = metrics

    # Drop bait sets and samples without data (e.g. a table header with no rows).
    # Build new dicts instead of popping while iterating, which raises RuntimeError.
    non_empty: Dict[str, Dict[str, Dict[str, Any]]] = dict()
    for sample, by_bait in data_by_bait_by_sample.items():
        kept = {bait: metrics for bait, metrics in by_bait.items() if metrics}
        if kept:
            non_empty[sample] = kept

    data_by_sample: Dict[str, Dict[str, Any]] = dict()
    source_by_row: Dict[str, Any] = dict()
    # Manipulate sample names if multiple baits found
    for sample, by_bait in non_empty.items():
        for bait, metrics in by_bait.items():
            row_name = f"{sample}: {bait}" if len(by_bait) > 1 else sample
            if row_name in data_by_sample:
                log.debug(f"Duplicate sample name found in {source_by_sample[sample]['fn']}! Overwriting: {row_name}")
            data_by_sample[row_name] = metrics
            source_by_row[row_name] = source_by_sample[sample]

    # Filter to strip out ignored sample names
    data_by_sample = module.ignore_samples(data_by_sample)
    if len(data_by_sample) == 0:
        return set()

    # Register data sources under the final row names, so they match the reported samples
    for row_name in data_by_sample:
        module.add_data_source(source_by_row[row_name], row_name, section="HsMetrics")

    # Superfluous function call to confirm that it is used in this module
    # Replace None with actual version if it is available
    module.add_software_version(None)

    # Write parsed data to a file
    module.write_data_file(data_by_sample, f"multiqc_{module.id}_HsMetrics")

    # Swap question marks with -1 (the column may be absent, e.g. in other tools' output)
    for metrics in data_by_sample.values():
        if metrics.get("FOLD_ENRICHMENT") == "?":
            metrics["FOLD_ENRICHMENT"] = -1

    # Add to general stats table
    _general_stats_table(module, data_by_sample)

    # Add report section
    module.add_section(
        name="Hybrid-selection metrics",
        anchor=f"{module.id}_hsmetrics",
        description="Parsed from Picard HsMetrics tool that takes a SAM/BAM file input and collects metrics that are specific for sequence datasets generated through hybrid-selection. Hybrid-selection (HS) is the most commonly used technique to capture exon-specific sequences for targeted sequencing experiments such as exome sequencing.",
        plot=table.plot(
            data_by_sample,
            _get_table_headers(),
            pconfig=TableConfig(
                id=f"{module.id}_hsmetrics_table",
                namespace="HsMetrics",
                scale="RdYlGn",
                min=0,
                title="Picard HsMetrics",
            ),
        ),
    )
    tbases = _add_target_bases(module, data_by_sample)
    module.add_section(
        name=tbases["name"],
        anchor=tbases["anchor"],
        description=tbases["description"],
        plot=tbases["plot"],
    )
    hs_pen_plot = hs_penalty_plot(module, data_by_sample)
    if hs_pen_plot is not None:
        module.add_section(
            name="Hybrid-selection penalty",
            anchor=f"{module.id}_hsmetrics_hs_penalty",
            description='The "hybrid selection penalty" incurred to get 80% of target bases to a given coverage.',
            helptext="""
                Can be used with the following formula:

                ```
                required_aligned_bases = bait_size_bp * desired_coverage * hs_penalty
                ```
            """,
            plot=hs_pen_plot,
        )

    # Return the number of detected samples to the parent module
    return set(data_by_sample.keys())


def _general_stats_table(module: BaseMultiqcModule, data: Dict[str, Any]):
    """
    Add HsMetrics columns to the General Stats table.

    If ``picard_config.HsMetrics_genstats_table_cols`` and/or
    ``picard_config.HsMetrics_genstats_table_cols_hidden`` are set, headers are generated for
    exactly those fields. Otherwise a default set is used:

    - fold enrichment
    - median target coverage
    - the percentage of target bases at each threshold in
      ``picard_config.general_stats_target_coverage`` (30X when unset, empty or not a list)

    :param module: module whose General Stats columns are extended.
    :param data: row name -> HsMetrics metric name -> value.
    """
    # Look for a user config of which table columns we should use
    picard_config = getattr(config, "picard_config", {})
    # `or []` also covers an explicit null in the user config
    genstats_table_cols = picard_config.get("HsMetrics_genstats_table_cols") or []
    genstats_table_cols_hidden = picard_config.get("HsMetrics_genstats_table_cols_hidden") or []

    headers: Dict[str, ColumnDict] = {}
    # Custom general stats columns
    if genstats_table_cols or genstats_table_cols_hidden:
        for k, v in _generate_table_header_config(genstats_table_cols, genstats_table_cols_hidden).items():
            headers[k] = v

    # Default General Stats headers
    else:
        headers["FOLD_ENRICHMENT"] = {
            "title": "Fold Enrichment",
            "min": 0,
            "format": "{:,.0f}",
            "scale": "Blues",
            "suffix": " X",
        }
        headers["MEDIAN_TARGET_COVERAGE"] = {
            "title": "Median Target Coverage",
            "description": "The median coverage of reads that mapped to target regions of an experiment.",
            "min": 0,
            "suffix": "X",
            "scale": "GnBu",
        }
        # Validate explicitly rather than with `assert`, which is stripped under `python -O`
        # (a string value would then be iterated character by character).
        custom_covs = picard_config.get("general_stats_target_coverage")
        covs: List[str]
        if isinstance(custom_covs, list) and custom_covs:
            covs = [str(c) for c in custom_covs]
            log.debug(f"Custom picard coverage thresholds: {', '.join(covs)}")
        else:
            covs = ["30"]
        for c in covs:
            headers[f"PCT_TARGET_BASES_{c}X"] = {
                "rid": f"{module.id}_target_bases_{c}X",
                "title": f"Target Bases ≥ {c}X",
                "description": f"Percent of target bases with coverage ≥ {c}X",
                "max": 100,
                "min": 0,
                "suffix": "%",
                "format": "{:,.0f}",
                "scale": "RdYlGn",
                "modify": util.multiply_hundred,
            }
    module.general_stats_addcols(data, headers, namespace="HsMetrics")


def _get_table_headers() -> Dict[ColumnKey, ColumnDict]:
    """
    Build the column headers for the main HsMetrics table.

    The visible and hidden column lists are read from ``picard_config.HsMetrics_table_cols``
    and ``picard_config.HsMetrics_table_cols_hidden``. Each falls back to a built-in
    default list when it is missing or empty.
    """
    default_visible = [
        "AT_DROPOUT",
        "BAIT_DESIGN_EFFICIENCY",
        "BAIT_TERRITORY",
        "FOLD_80_BASE_PENALTY",
        "FOLD_ENRICHMENT",
        "GC_DROPOUT",
        "HET_SNP_Q",
        "HET_SNP_SENSITIVITY",
        "NEAR_BAIT_BASES",
        "OFF_BAIT_BASES",
        "ON_BAIT_BASES",
        "ON_TARGET_BASES",
        "PCT_USABLE_BASES_ON_BAIT",
        "PCT_USABLE_BASES_ON_TARGET",
        "PF_BASES",
        "PF_BASES_ALIGNED",
        "PF_READS",
        "PCT_SELECTED_BASES",
        "PF_UNIQUE_READS",
        "PF_UQ_BASES_ALIGNED",
        "PF_UQ_READS_ALIGNED",
        "TOTAL_READS",
        "MAX_TARGET_COVERAGE",
        "MEAN_BAIT_COVERAGE",
        "MEAN_TARGET_COVERAGE",
        "MEDIAN_TARGET_COVERAGE",
        "ON_BAIT_VS_SELECTED",
        "TARGET_TERRITORY",
        "ZERO_CVG_TARGETS_PCT",
    ]
    default_hidden = [
        "BAIT_TERRITORY",
        "TOTAL_READS",
        "TARGET_TERRITORY",
        "AT_DROPOUT",
        "GC_DROPOUT",
    ]

    user_settings = getattr(config, "picard_config", {})
    visible_cols = user_settings.get("HsMetrics_table_cols") or default_visible
    hidden_cols = user_settings.get("HsMetrics_table_cols_hidden") or default_hidden

    return _generate_table_header_config(visible_cols, hidden_cols)


def _generate_table_header_config(table_cols: List[str], hidden_table_cols: List[str]) -> Dict[ColumnKey, ColumnDict]:
    """
    Automatically generate some nice table header configs based on what we know about
    the different types of Picard data fields.

    Each header gets:

    - a title derived from the field name
    - a description from FIELD_DESCRIPTIONS; coverage-templated fields such as
      ``PCT_TARGET_BASES_30X`` use their ``{coverage}`` template
    - fraction-to-percentage formatting for ``PCT`` fields, or a shared key for
      read and base counts

    Fields listed in ``hidden_table_cols`` are added as hidden columns.

    :param table_cols: field names to show as visible columns.
    :param hidden_table_cols: field names to add as hidden columns.
    :return: column key -> column config, in first-seen order.
    """
    title_cleanup = [
        ("CVG", "coverage"),
        ("UQ", "unique"),
        ("ON_", "On-"),
        ("OFF_", "Off-"),
        ("NEAR_", "Near-"),
        ("_", " "),
        ("PCT", ""),
    ]
    # Turns e.g. "HS_PENALTY_20X" into the FIELD_DESCRIPTIONS template key "HS_PENALTY_{coverage}X"
    coverage_re = re.compile(r"_(\d+)X")

    def _is_known(field: str) -> bool:
        """Whether the field is a known Picard HsMetrics field, including coverage-templated ones."""
        return (
            field in FIELD_DESCRIPTIONS
            or field.startswith("PCT_TARGET_BASES_")
            or coverage_re.sub("_{coverage}X", field) in FIELD_DESCRIPTIONS
        )

    # Warn if we see anything unexpected
    unknown: Set[str] = set()
    for c in table_cols + hidden_table_cols:
        if not _is_known(c):
            log.error(f"Field '{c}' not found in expected Picard fields. Please check your config.")
            unknown.add(c)

    headers: Dict[ColumnKey, ColumnDict] = dict()
    for h in table_cols + hidden_table_cols:
        key = ColumnKey(h)
        # A field listed as both visible and hidden keeps its first (visible) config
        if key in headers:
            continue

        # Generate a nice string for the column title
        h_title = h
        for s, r in title_cleanup:
            h_title = h_title.replace(s, r)

        # Extract the coverage from the column name
        m = re.match(r".+_(\d+)X", h)
        if m:
            descr = FIELD_DESCRIPTIONS.get(coverage_re.sub("_{coverage}X", h), "").replace("{coverage}", m.group(1))
        else:
            descr = FIELD_DESCRIPTIONS.get(h, "")
        # Unknown fields were already reported above; don't log them twice
        if not descr and h not in unknown:
            log.warning(f"Field '{h}' not found in FIELD_DESCRIPTIONS, no column description available.")

        header: ColumnDict = {
            "title": h_title.strip().lower().capitalize(),
            "description": descr,
        }
        if "PCT" in h:
            # Picard reports these as fractions; display as percentages
            header["modify"] = util.multiply_hundred
            header["max"] = 100
            header["suffix"] = "%"
        elif "READS" in h:
            header["shared_key"] = "read_count"
        elif "BASES" in h:
            header["shared_key"] = "base_count"

        # Manual capitalisation for some strings
        header["title"] = header["title"].replace("Pf", "PF").replace("snp", "SNP")

        if h in hidden_table_cols:
            header["hidden"] = True

        headers[key] = header

    return headers


def _add_target_bases(module: BaseMultiqcModule, data: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """
    Build the "percentage of target bases at >= N X coverage" line graph.

    Reads every ``PCT_TARGET_BASES_<N>X`` field of every row and plots N against the value
    scaled to a percentage. The x axis extends to the highest N with a non-zero value.
    Non-numeric values (kept as strings by the parser) are skipped.

    :param module: module providing the plot ID/anchor prefix and name.
    :param data: row name -> HsMetrics metric name -> value.
    :return: dict with ``name``, ``anchor``, ``description`` and ``plot`` for a report section.
    """
    target_re = re.compile(r"PCT_TARGET_BASES_(\d+)X")
    data_clean: Dict[str, Dict[int, float]] = defaultdict(dict)
    max_non_zero_cov = 0
    for s, metrics in data.items():
        for field, bases_pct in metrics.items():
            m = target_re.fullmatch(field)
            # Skip unrelated fields and unparsable values (e.g. "?" or "")
            if m is None or not isinstance(bases_pct, (int, float)):
                continue
            cov = int(m.group(1))
            data_clean[s][cov] = bases_pct * 100.0
            if bases_pct > 0 and cov > max_non_zero_cov:
                max_non_zero_cov = cov

    pconfig = {
        "id": f"{module.anchor}_percentage_target_bases",
        "title": f"{module.name} HSMetrics: percentage of target base pairs",
        "xlab": "Fold coverage",
        "ylab": "Percentage of base pairs",
        "ymax": 100,
        "ymin": 0,
        "xmin": 0,
        "xmax": max_non_zero_cov,
        "tt_label": "<b>{point.x}X</b>: {point.y:.2f}%",
    }
    return {
        "name": "Hybrid-selection target coverage",
        "anchor": f"{module.anchor}_hsmetrics_target_bases",
        "description": "The percentage of all target bases with at least <code>x</code> fold coverage.",
        "plot": linegraph.plot(data_clean, pconfig),
    }


def hs_penalty_plot(module: BaseMultiqcModule, data: Dict[str, Dict[str, Any]]):
    """
    Build a line graph of the hybrid-selection penalty against target coverage.

    Reads every ``HS_PENALTY_<N>X`` field of every row and plots N against the penalty.
    Non-numeric values (kept as strings by the parser) are skipped.

    :param module: module providing the plot ID prefix and name.
    :param data: row name -> HsMetrics metric name -> value.
    :return: the line graph, or None when no row has a positive penalty value.
    """
    # Exact pattern match: str.lstrip() would strip a character *set*, not a prefix
    penalty_re = re.compile(r"HS_PENALTY_(\d+)X")
    data_clean: Dict[str, Dict[int, float]] = defaultdict(dict)
    any_non_zero = False
    for s, metrics in data.items():
        for field, penalty in metrics.items():
            m = penalty_re.fullmatch(field)
            if m is None or not isinstance(penalty, (int, float)):
                continue
            data_clean[s][int(m.group(1))] = penalty
            if penalty > 0:
                any_non_zero = True

    if not any_non_zero:
        return None

    pconfig = {
        "id": f"{module.anchor}_hybrid_selection_penalty",
        "title": f"{module.name}: Hybrid-selection penalty",
        "xlab": "Fold coverage",
        "ylab": "Penalty",
        "ymin": 0,
        "xmin": 0,
        "x_decimals": False,
        # The penalty is a multiplier, not a percentage
        "tt_label": "<b>{point.x}X</b>: {point.y:.2f}",
    }
    return linegraph.plot(data_clean, pconfig)
