#!/usr/bin/env python3

"""
Script to generate a document about Arti RPC objects and methods.

- Extracts a list of methods and objects by running Arti RPC.
- Extracts markdown documentation by running nightly rustdoc to get JSON output.

- Outputs markdown or html.

Requires rustdoc nightly, for json support.
May require specific versions of rustdoc nightly,
since the json format is unstable.

Known to work with nightly from 2025-01-17 (see KNOWN_GOOD_NIGHTLY below).
"""

import io
import json
import os
import re
import subprocess
import sys

# What version of rustdoc json do we know?
# (Compared against the "format_version" field of each rustdoc json file we load.)
SUPPORTED_FORMAT_VERSION = 38
# What nightly do we know to provide that version?
# (Only mentioned in the warning that we print on a version mismatch.)
KNOWN_GOOD_NIGHTLY = "nightly-2025-01-17"

# Tricky bits, unsolved
#   - Converting links in the md to be correct.
#   - Extracting only relevant portions of the rustdoc.

# Location of the arti checkout: the parent of the directory holding this script.
ARTI_ROOT = os.path.split(os.path.dirname(__file__))[0]

# Where will we find our rustdoc files?
#
# (This assumes we're being run from an arti checkout.)
TARGET_DIR = os.path.join(ARTI_ROOT, "target")

# Base URL of the published rustdoc.  The links that we generate to Rust items
# are formed by appending the item's path to this prefix.
RUSTDOC_ROOT = "https://tpo.pages.torproject.net/core/doc/rust/private/"

# The heading text that identifies the RPC-specific section of an item's
# markdown documentation.  (Matched case-insensitively.)
MAGIC_HEADING = "In The Arti RPC System"

# HTML that we wrap around the rendered document when producing HTML output.
HTML_HEADER = """\
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Arti RPC: methods and objects</title>
</head>
"""
HTML_FOOTER = """\
</html>
"""

# This isn't installed by default, so try to give a helpful message.
try:
    import marko
    import marko.md_renderer
except ImportError:
    print(
        "You need to install marko so that we can parse and manipulate markdown.",
        file=sys.stderr,
    )
    print("Try `pip3 install marko`.", file=sys.stderr)
    sys.exit(1)

# This is part of this crate, so try using the version here if we can't
# find a system copy.
try:
    import arti_rpc
except ImportError:
    print("arti_rpc not installed. Faking it...", file=sys.stderr)
    pypath = os.path.join(ARTI_ROOT, "python", "arti_rpc", "src")
    sys.path.append(pypath)
    import arti_rpc


def load_methods_via_rpc():
    """
    Connect to Arti via RPC and ask it for a list of RPC methods.

    Return value is as for "x_list_all_rpc_methods".

    If the connection cannot be established, print a diagnostic to stderr
    and exit the process with status 1.
    """

    try:
        connection = arti_rpc.ArtiRpcConn()
    except arti_rpc.ArtiRpcError as e:
        # Diagnostics go to stderr, like every other status message in this
        # tool: stdout may be the stream that the document is written to.
        print(f"Cannot connect to Arti via RPC: {e}", file=sys.stderr)
        print(
            "Is Arti running? Is it built with RPC support? Did you configure it to listen on RPC?",
            file=sys.stderr,
        )
        sys.exit(1)

    # Get a list of RPC methods.
    methods = connection.session().invoke("arti:x_list_all_rpc_methods")
    del connection  # TODO: Implement a connection.close(); Implement "with connection".
    return methods


def add_additional_methods(methods):
    """
    Insert, in place, the entries for methods that
    `arti:x_list_all_rpc_methods` does not ordinarily report.
    """
    # `rpc:release` is missing from the listing because it applies to every
    # object.  Eventually we may add a better way to solve this.
    release_entry = {
        "applies_to_object_types": "*",
        "method_type": "arti_rpcserver::objmap::methods::RpcRelease",
        "output_type": "tor_rpcbase::Nil",
        "update_type": None,
    }
    table = methods["methods"]
    table["rpc:release"] = release_entry


def mdheader(s, anchor):
    """
    Wrap the markdown header line `s` (which already includes its `#` prefix)
    in a `<div>` whose id is `anchor`, so that the header can be linked to.

    The result begins with a blank line and ends with two newlines.
    """
    template = '\n<div id="{}">\n\n{}\n\n</div>\n\n'
    return template.format(anchor, s)


def doc_subject(doc):
    """
    Return the first paragraph of `doc`, joined onto a single line.

    Lines are stripped of surrounding whitespace, and collected until the
    first line that is blank after stripping (or until the end of `doc`).
    By rustdoc convention, that paragraph summarizes the documented item.
    """
    summary = []
    for raw_line in doc.split("\n"):
        line = raw_line.strip()
        if line == "":
            # The first blank line ends the summary.
            break
        summary.append(line)
    return " ".join(summary)


# An object to parse and render markdown.
#
# It is configured with marko's MarkdownRenderer, and `adjust_doc` uses it
# to turn a modified document back into text.
MD = marko.Markdown(renderer=marko.md_renderer.MarkdownRenderer)


def heading_matches(h, s):
    """
    Return true if the markdown Heading object `h` consists of a single
    RawText child whose text equals `s`.

    The comparison ignores case and surrounding whitespace.
    """
    parts = h.children
    if len(parts) == 1 and parts[0].get_type() == "RawText":
        found = parts[0].children.strip().lower()
        wanted = s.strip().lower()
        return found == wanted
    # Empty headings, and headings with inline markup, never match.
    return False


def extract_md_section(document, magic_heading=MAGIC_HEADING):
    """
    Look in the markdown `document` for the special section whose heading
    is `magic_heading`, and narrow the document down to that section.

    The section's own heading is not kept.  The section runs until the next
    heading whose level is the same or more important (numerically less
    than or equal); headings nested more deeply belong to the section.

    If nothing is collected (the heading is absent, or its section is
    empty), `document` is left untouched.
    """
    kept = []
    inside = False
    base_level = None

    for node in document.children:
        if node.get_type() != "Heading":
            # Ordinary content is kept only while we're inside the section.
            if inside:
                kept.append(node)
            continue

        if inside:
            if node.level > base_level:
                # A sub-heading of the section we want.
                kept.append(node)
            else:
                # A sibling or parent heading: the section has ended.
                inside = False
        elif heading_matches(node, magic_heading):
            inside = True
            base_level = node.level

    # Only replace the document if we actually collected something.
    if kept:
        document.children = kept


def recursively_evaluate(doc, function):
    """
    Call 'function' on `doc`, and then on each of its descendants,
    visiting every element before its own children.

    An element whose `children` is not a list (for example, a string)
    is treated as a leaf; 'function' is never called on strings.
    """
    function(doc)
    kids = doc.children
    if not isinstance(kids, list):
        return
    for child in kids:
        recursively_evaluate(child, function)


def fix_rustdoc_links(doc):
    """
    Repair, in place, the rustdoc links that marko can't resolve by itself.

    These are the links written as `[foo](bar)`, where `bar` is a Rust
    identifier.  We do generate a link definition for each such `bar`,
    but marko "normalizes" the labels of link definitions, so the
    unnormalized destination no longer matches any of them.
    """
    # Pass 1: collect every link def ("[foo]: https://...") in the document.
    targets_by_label = dict()

    def collect(node):
        if node.get_type() == "LinkRefDef":
            targets_by_label[node.label] = node.dest

    recursively_evaluate(doc, collect)

    # Pass 2: for each link whose destination is really an unnormalized
    # link-def label, point it at that definition's destination instead.
    # This effectively turns [text](UnnormalizedLabel) into
    # [text](destination), which marko will render as [text][normalizedlabel].
    def repair(node):
        if node.get_type() != "Link":
            return
        original = node.dest
        normalized = marko.helpers.normalize_label(original)
        if original.startswith("http"):
            # Already a real URL.
            return
        if original in targets_by_label:
            return
        if normalized in targets_by_label:
            node.dest = targets_by_label[normalized]

    recursively_evaluate(doc, repair)


def adjust_doc(doc, link_table, outer_h_level=3):
    """
    Modify a markdown document that we've gotten from rustdoc json:

    - If there is a section named with `MAGIC_HEADING`,
      include only that section.

    - Increment or decrement the heading level so that the entire document
      nests within a heading of level outer_h_level.

    - Adjust every link that we find.

    `doc` is the markdown text, and `link_table` maps each rustdoc link
    label to the URL it should point to.

    Returns a string rendering of the document.
    """
    parser = marko.parser.Parser()

    # We add a bunch of Rustdoc link definitions to the end of the document,
    # since otherwise marko won't know how to handle them.
    link_defs = [f"[{k}]: {v}\n" for k, v in link_table.items()]
    document = parser.parse("".join([doc, "\n\n", *link_defs]))

    ####
    # Pull out the section that's called MAGIC_HEADING, if any.
    extract_md_section(document)

    ####
    # Adjust heading levels so that they nest within outer_h_level.
    #
    # Note that `get_type` is a method and has to be called: comparing the
    # method object itself with a string is always false, which would
    # silently leave every heading at its original level.
    headings = [elt for elt in document.children if elt.get_type() == "Heading"]
    if headings:
        # Shift so that the most important heading lands just below
        # outer_h_level.
        adjustment = outer_h_level + 1 - min(h.level for h in headings)
        for heading in headings:
            # Markdown only has heading levels 1 through 6.
            heading.level = min(6, max(1, heading.level + adjustment))

    ####
    # Make rustdoc links of the form [`foo`](RustId) actually work correctly.
    fix_rustdoc_links(document)

    return MD.render(document)


def un_generic(s):
    """
    Strip the generics from the rust type `s`, by discarding everything
    from the first `<` onward.  A type without any `<` is returned as-is.
    """
    cut = s.find("<")
    return s if cut < 0 else s[:cut]


class RustIdent:
    """
    A Rust identifier, taken from a Rust type returned by Arti RPC.

    Any generic parameters are stripped on construction.  Instances are
    hashable, and are compared and ordered by their identifier string.
    """

    # Fields
    #
    # ident: The full `::`-separated path of the type, without generics.
    def __init__(self, ident):
        assert ident is not None
        self.ident = un_generic(ident)

    def crate(self):
        """Return the crate in which this type is declared."""
        # Note: We unconditionally replace _ with - in the crate name.
        # This is valid for Arti crates, which are never named with _,
        # but it isn't valid in general.
        return self.ident.split("::")[0].replace("_", "-")

    def __eq__(self, other):
        # Returning NotImplemented (rather than reading `other.ident`)
        # lets a comparison with None or a str evaluate to "not equal"
        # instead of raising AttributeError.
        if not isinstance(other, RustIdent):
            return NotImplemented
        return self.ident == other.ident

    def __lt__(self, other):
        # As in __eq__: let Python raise its usual TypeError for
        # unorderable operands.
        if not isinstance(other, RustIdent):
            return NotImplemented
        return self.ident < other.ident

    def __hash__(self):
        return hash(self.ident)

    def __repr__(self):
        return f"RustIdent({self.ident!r})"

    def __str__(self):
        return self.ident


class RpcUniverse:
    """
    Everything we know about the RPC system: its methods, the object types
    they apply to, and the types they produce.
    """

    # Attributes
    #
    # methods: Maps each method name to a dict with the keys
    #      `applies_to_object_types`, `method_type`, `output_type`,
    #      `update_type` (possibly None), and `universal`.
    # delegations: Maps the RustIdent of an RPC Object type to the
    #      list of the RustIdents which it can delegate to.
    # rustdoc: Maps a crate name to the Rustdoc object loaded for it.
    def __init__(self, method_list):
        self.methods = {}
        for name, raw in method_list["methods"].items():
            raw_targets = raw["applies_to_object_types"]
            # "*" marks a method that every object implements; it is
            # recorded as a flag rather than as a type.
            is_universal = "*" in raw_targets
            targets = [RustIdent(t) for t in raw_targets if t != "*"]
            method_ty = RustIdent(raw["method_type"])
            output_ty = RustIdent(raw["output_type"])
            raw_update = raw["update_type"]
            update_ty = None if raw_update is None else RustIdent(raw_update)
            self.methods[name] = {
                "applies_to_object_types": targets,
                "method_type": method_ty,
                "output_type": output_ty,
                "update_type": update_ty,
                "universal": is_universal,
            }

        self.delegations = {}
        for source, delegates in method_list["delegations"].items():
            self.delegations[RustIdent(source)] = [RustIdent(d) for d in delegates]

        self.apply_delegations()
        self.rustdoc = {}

    def method_types(self):
        """Return the set of the Rust types of every known RPC method."""
        return {info["method_type"] for info in self.methods.values()}

    def object_types(self):
        """Return the set of every RPC object type that some method applies to."""
        per_method = (info["applies_to_object_types"] for info in self.methods.values())
        return set().union(*per_method)

    def output_types(self):
        """Return the set of every type that an RPC method can output."""
        return {info["output_type"] for info in self.methods.values()}

    def update_types(self):
        """Return the set of every type that an RPC method can send as an update."""
        return {
            info["update_type"]
            for info in self.methods.values()
            if info["update_type"] is not None
        }

    def all_types(self):
        """Return the set of all rust types that are relevant for RPC."""
        return (
            self.method_types()
            | self.object_types()
            | self.output_types()
            | self.update_types()
        )

    def relevant_crates(self):
        """Return the set of crates that contain a type relevant to RPC."""
        return {ident.crate() for ident in self.all_types()}

    def apply_delegations(self):
        """
        Extend the `applies_to_object_types` list of every method, in place,
        with the delegation targets of the types that are already listed.
        """
        for info in self.methods.values():
            targets = info["applies_to_object_types"]
            # Collect first, append afterwards: we mustn't grow the list
            # while we are still scanning it.
            inherited = [
                delegate
                for target in targets
                if target in self.delegations
                for delegate in self.delegations[target]
            ]
            for delegate in inherited:
                if delegate not in targets:
                    targets.append(delegate)

    def build_and_load_rustdoc(self):
        """
        For every relevant crate, build its rustdoc json and load it
        into `self.rustdoc`.
        """
        for crate in self.relevant_crates():
            build_rustdoc_json(crate)
            # The json file is named after the crate, with `_` in place of `-`.
            json_name = crate.replace("-", "_") + ".json"
            json_path = os.path.join(TARGET_DIR, "doc", json_name)
            with open(json_path, "r", encoding="UTF-8") as fh:
                parsed = json.load(fh)
                self.rustdoc[crate] = Rustdoc(parsed)

    def get_methods_for_obj(self, ident):
        """
        Yield the name of every method that can apply to the object `ident`.
        """
        yield from (
            name
            for name, info in self.methods.items()
            if ident in info["applies_to_object_types"]
        )

    def get_doc(self, ident):
        """
        Look up the rustdoc text for the object `ident`.
        """
        crate_docs = self.rustdoc[ident.crate()]
        return crate_docs.get_doc(ident)

    def get_link_table(self, ident):
        """
        Look up the link table needed to interpret the links in the rustdoc
        for the object `ident`.
        """
        crate_docs = self.rustdoc[ident.crate()]
        return crate_docs.get_link_table(ident)

    def get_rustdoc_url(self, ident):
        """
        Look up the canonical rustdoc URL (within RUSTDOC_ROOT) for `ident`.
        """
        crate_docs = self.rustdoc[ident.crate()]
        return crate_docs.get_rustdoc_url(ident)

    def emit_method_index(self, f):
        """
        Write to `f` a one-line-per-method index of all known RPC methods.
        """
        f.write(mdheader("### Method index", "idx:methods"))
        for name, info in sorted(self.methods.items()):
            summary = doc_subject(self.get_doc(info["method_type"]))
            f.write(f"- [`{name}`](#method:{name}) — {summary}\n")
        f.write("\n\n")

    def emit_single_method(self, f, m_str, obj):
        """
        Write to `f` the documentation for the single method named `m_str`,
        described by the "minfo" dict `obj`.
        """
        method_ty = obj["method_type"]
        doc_text = self.get_doc(method_ty)
        links = self.get_link_table(method_ty)
        summary = doc_subject(doc_text)
        f.write(mdheader(f"### `{m_str}` — {summary}", "method:" + m_str))
        f.write(adjust_doc(doc_text, links))
        f.write("\n\n")

        # A list of facts about the method, starting with its rustdoc.
        url = self.get_rustdoc_url(method_ty)
        f.write(f"- [Rustdoc]({url})\n")
        no_params = self.rustdoc[method_ty.crate()].type_has_no_fields(method_ty)
        f.write(
            "- (Takes no parameters)\n"
            if no_params
            else f"- [Parameters]({url}#fields)\n"
        )

        output_ty = obj["output_type"]
        output_url = self.get_rustdoc_url(output_ty)
        f.write(f"- Returns [`{output_ty}`]({output_url})\n")

        update_ty = obj["update_type"]
        if update_ty is None:
            f.write("- (No incremental updates)\n")
        else:
            update_url = self.get_rustdoc_url(update_ty)
            f.write(f"- Yields incremental updates of [`{update_ty}`]({update_url})\n")

        if obj["universal"]:
            f.write("- **Implemented by all objects**\n")
        else:
            f.write("- **Implemented by**\n")
            for impl_ty in obj["applies_to_object_types"]:
                impl_url = self.get_rustdoc_url(impl_ty)
                f.write(
                    f"    - [`{impl_ty}`](#object:{impl_ty}) ([Rustdoc]({impl_url}))\n"
                )

        f.write("\n\n")

    def emit_methods(self, f):
        """
        Write to `f` the documentation of every method, sorted by name.
        """
        for name, info in sorted(self.methods.items()):
            self.emit_single_method(f, name, info)

    def emit_object_index(self, f):
        """
        Write to `f` a one-line-per-type index of all known RPC object types.
        """
        f.write(mdheader("### Object index", "idx:objects"))

        for obj_ty in sorted(self.object_types()):
            summary = doc_subject(self.get_doc(obj_ty))
            f.write(f"- [`{obj_ty}`](#object:{obj_ty}) — {summary}\n")
        f.write("\n\n")

    def emit_single_object(self, f, o_name):
        """
        Write to `f` the documentation for the single RPC object type `o_name`.
        """
        doc_text = self.get_doc(o_name)
        links = self.get_link_table(o_name)
        summary = doc_subject(doc_text)
        f.write(mdheader(f"### `{o_name}` — {summary}", f"object:{o_name}"))
        f.write(adjust_doc(doc_text, links))
        f.write("\n\n")
        url = self.get_rustdoc_url(o_name)
        f.write(f"- [Rustdoc]({url})\n")
        f.write("- **Implements methods**\n")

        for method_name in sorted(self.get_methods_for_obj(o_name)):
            f.write(f"   - [`{method_name}`](#method:{method_name})\n")
        f.write("\n\n")

    def emit_objects(self, f):
        """
        Write to `f` the documentation of every object type, in sorted order.
        """
        for obj_ty in sorted(self.object_types()):
            self.emit_single_object(f, obj_ty)

    def emit_docs(self, f):
        """
        Write the whole markdown document to `f`: first the methods,
        then a separator, then the objects.
        """
        f.write(mdheader("## Methods", "methods"))
        self.emit_method_index(f)
        self.emit_methods(f)
        f.write("\n----\n")

        f.write(mdheader("## Objects", "objects"))
        self.emit_object_index(f)
        self.emit_objects(f)


# Set once `warn_about_json_version` has printed its warning, so that the
# warning appears at most once per run.
have_warned_about_json_version = False


def warn_about_json_version(actual_version):
    """
    If we haven't previously done so, warn the user (on stderr) that their
    nightly is generating an unrecognized version of rustdoc json.

    `actual_version` is the format version that we found instead of
    SUPPORTED_FORMAT_VERSION.
    """
    global have_warned_about_json_version
    if not have_warned_about_json_version:
        have_warned_about_json_version = True
        msg = f"""
              WARNING: Rustdoc json is in format version {actual_version},
              but this tool expects {SUPPORTED_FORMAT_VERSION}.
              If this fails, you might need to fix the tool,
              or switch to {KNOWN_GOOD_NIGHTLY}.
        """
        # Collapse the indentation and line breaks of the literal above
        # into single spaces.
        msg = re.sub(r"\s+", " ", msg).strip()
        # This must not go to stdout, which may be the stream that the
        # generated document is written to.
        print(msg, file=sys.stderr)


class Rustdoc:
    """
    The rustdoc for a single crate.
    """

    # Fields
    #
    # doc: The parsed rustdoc json document for the crate.
    # idx_by_ident: A map from RustIdent to the key under which that item
    #   is stored in the "paths" and "index" tables of `doc`.
    def __init__(self, json_doc):
        actual_version = json_doc.get("format_version")
        if actual_version != SUPPORTED_FORMAT_VERSION:
            # We try to carry on regardless; the user is warned only once.
            warn_about_json_version(actual_version)

        self.doc = json_doc

        self.idx_by_ident = dict()
        for idx, obj in self.doc["paths"].items():
            path = "::".join(obj["path"])
            self.idx_by_ident[RustIdent(path)] = idx

    def get_doc(self, ident):
        """
        Find the string holding documentation for a single identifier within this crate.

        Returns an empty string if the item has no documentation.
        Raises KeyError if `ident` is not known in this crate.
        """
        obj = self.doc["index"][self.idx_by_ident[ident]]
        # Rustdoc emits `"docs": null` for an undocumented item, but our
        # callers split and concatenate the result as a string.
        docs = obj["docs"]
        return docs if docs is not None else ""

    def get_link_table(self, ident):
        """
        Return a dict mapping reference link ID to corresponding URLs for
        the documentation of `ident`.
        """
        obj = self.doc["index"][self.idx_by_ident[ident]]
        links = obj["links"]
        table = {}
        for content, idx in links.items():
            table[content] = self.get_rustdoc_url_by_idx(idx)
        return table

    def get_rustdoc_url(self, ident):
        """
        Return a URL for the documentation for an identifier within this crate,
        relative to a rustdoc installation at RUSTDOC_ROOT.
        """
        idx = self.idx_by_ident[ident]
        return self.get_rustdoc_url_by_idx(idx)

    def get_rustdoc_url_by_idx(self, idx):
        """
        Return a URL according to a rustdoc "index".
        """
        # The keys of "paths" are strings, but `idx` may not be one.
        obj = self.doc["paths"][str(idx)]
        kind = obj["kind"]
        # Copy the path, so that we don't modify the loaded document.
        path = obj["path"][:]
        path[-1] = f"{kind}.{path[-1]}.html"
        return RUSTDOC_ROOT + "/".join(path)

    def type_has_no_fields(self, ident):
        """
        Return true if `ident` is a type that definitely has no fields.

        That holds for unit structs, and for plain (braced) structs with an
        empty field list.  Anything else, including anything that we can't
        interpret, gives False.
        """
        obj = self.doc["index"][self.idx_by_ident[ident]]
        try:
            kind = obj["inner"]["struct"]["kind"]
        except (KeyError, TypeError):
            # Not a struct at all.
            return False

        # A unit struct is represented by the bare string "unit", not by a
        # dict; indexing into it would raise TypeError.
        if kind == "unit":
            return True

        try:
            fields = kind["plain"]["fields"]
        except (KeyError, TypeError):
            # A tuple struct, or a shape that we don't recognize.
            return False

        return len(fields) == 0


def build_rustdoc_json(crate):
    """
    Run `cargo +nightly rustdoc` to build the json rustdoc for the single
    crate `crate`, private items included.

    Raises subprocess.CalledProcessError if cargo fails.
    """
    cargo_cmd = [
        "cargo",
        "+nightly",
        "rustdoc",
        "--quiet",
        "--output-format",
        "json",
        "--all-features",
        "-Zunstable-options",
    ]
    package_selector = ["-p", crate]
    # Everything after "--" is handed to rustdoc itself.
    rustdoc_flags = ["--", "--document-private-items"]
    subprocess.run(cargo_cmd + package_selector + rustdoc_flags, check=True)


def run(output, fmt="md"):
    """Gather information from arti and rustdoc, and write the combined
    document to `output`.

    The document is written as markdown if `fmt` is "md", and as HTML if
    `fmt` is "html".  Progress messages go to stderr.
    """

    def progress(message):
        print(message, file=sys.stderr)

    progress("== STEP 1: Asking Arti RPC for a list of types and methods.")

    methods = load_methods_via_rpc()
    add_additional_methods(methods)
    universe = RpcUniverse(methods)

    n_methods = len(methods["methods"])
    n_types = len(universe.all_types())
    progress(f"Found {n_methods} methods and {n_types} relevant Rust types.")

    progress("== STEP 2: Extracting rustdoc as json")

    universe.build_and_load_rustdoc()

    progress("== STEP 3: Emitting markdown")
    # Markdown can go straight to the output; for anything else we
    # collect it in memory first.
    sink = output if fmt == "md" else io.StringIO()

    universe.emit_docs(sink)

    if fmt != "html":
        return

    progress("== STEP 4: Converting to HTML")
    converter = marko.Markdown()
    tree = converter.parse(sink.getvalue())
    output.write(HTML_HEADER)
    output.write(converter.render(tree))
    output.write(HTML_FOOTER)


def main(args):
    """Invoke rpc-docs-tool using the command-line arguments in args.

    (Make sure to omit sys.argv[0], or you will overwrite "rpc-docs-tool")

    If no `--format` is given, the format is guessed from the extension of
    the output file name, falling back to markdown.
    """
    import argparse

    parser = argparse.ArgumentParser(
        prog="rpc-docs-tool", description="Generate RPC method docs"
    )
    # The document contains non-ASCII text, and our HTML header declares
    # UTF-8, so don't leave the encoding up to the locale.
    parser.add_argument("output", type=argparse.FileType("w", encoding="UTF-8"))
    parser.add_argument("--format", default=None, choices=["md", "html"], dest="fmt")
    args = parser.parse_args(args)

    # Decide what file format to use.
    fmt = args.fmt
    if fmt is None:
        # Nothing may be printed to stdout here: when the output argument
        # is "-", stdout is the document itself.
        extension = os.path.splitext(args.output.name)[1]
        if extension in (".md", ".html"):
            fmt = extension[1:]
        else:
            fmt = "md"  # default

    run(args.output, fmt)


# When run as a script, hand main() the command-line arguments,
# without the program name.
if __name__ == "__main__":
    main(sys.argv[1:])
