#!/usr/bin/env python3
"""This script parses the documentation provided by LAPACK to generate a header
file declaring prototypes for BLAS and LAPACK functions and definitions of some
type-generic wrappers (e.g., four overloads for GEMM based on the input types).

To use this file: running

    ./generate_lapack_templates --help

prints descriptions of the different command line arguments. In brief: this file
reads the three text files in lapack_templates_data/ to determine which wrappers
it should generate. It generates the wrappers by parsing LAPACK's documentation:
the location of the library must be provided (with the --lapack-root-directory
argument).

To add more files: simply put the name of the subroutine (without the type
prefix, i.e., gemv instead of dgemv) in lapack_templates_data/blas_subroutines
(for BLAS routines) or lapack_templates_data/lapack_subroutines (for LAPACK
routines). If any subroutines need to be run with floating point exceptions
disabled then also add them to lapack_templates_data/fp_subroutines.

"""

import argparse
import glob
import os
import os.path as osp
import re
import sys
import time


# C++ types for FORTRAN type names given as plain strings.
#
# TODO the FORTRAN standard states that LOGICAL (default KIND) should have the
# same storage size as INTEGER and REAL. It's not clear what happens when we
# have 64 bit BLAS and those are no longer equal, so this choice might not work
# on big-endian machines.
_FORTRAN_NAME_TO_CPP_TYPE = {
    'CHARACTER*1': 'char',
    'CHARACTER': 'char',
    'INTEGER': 'dealii::types::blas_int',
    'LOGICAL': 'dealii::types::blas_int',
    'REAL': 'float',
    'DOUBLE': 'double',
    'DOUBLE PRECISION': 'double',
    'COMPLEX': 'std::complex<float>',
    'COMPLEX*16': 'std::complex<double>',
}

# C++ types for FORTRAN types given as compiled regular expressions; the keys
# are the pattern strings of those regular expressions. The LOGICAL caveat
# above applies here too.
_FORTRAN_PATTERN_TO_CPP_TYPE = {
    'CHARACTER\\*1': 'char',
    'CHARACTER': 'char',
    'INTEGER': 'dealii::types::blas_int',
    'LOGICAL': 'dealii::types::blas_int',
    'REAL': 'float',
    'DOUBLE (PRECISION)?': 'double',
    'COMPLEX[^*]': 'std::complex<float>',
    'COMPLEX\\*16': 'std::complex<double>',
}


def fortran_to_cpp_type(regex):
    """Return a C++ type corresponding to the provided regular expression that
    matches a FORTRAN type.

    regex may either be a plain string containing a FORTRAN type name (e.g.,
    'DOUBLE PRECISION') or a compiled regular expression object, in which case
    its pattern string is used for the lookup.

    Raises ValueError if the type is not one of the known FORTRAN types. An
    explicit exception is used (rather than 'assert False') so that the check
    is not removed when Python runs with -O.
    """
    if isinstance(regex, str):
        table = _FORTRAN_NAME_TO_CPP_TYPE
        type_name = regex
    else:
        table = _FORTRAN_PATTERN_TO_CPP_TYPE
        type_name = regex.pattern

    try:
        return table[type_name]
    except KeyError:
        raise ValueError(
            "Unknown FORTRAN type {!r}.".format(type_name)) from None


def field_code_to_cpp(string):
    """Convert a field code (s, d, c, or z) to a C++ type.

    Raises ValueError for any other input. An explicit exception is used
    (rather than 'assert False') so that the check survives running Python
    with -O.
    """
    cpp_types = {
        's': 'float',
        'd': 'double',
        'c': 'std::complex<float>',
        'z': 'std::complex<double>',
    }
    try:
        return cpp_types[string]
    except KeyError:
        raise ValueError(
            "Unknown field code {!r}: expected one of 's', 'd', 'c', or 'z'."
            .format(string)) from None


def field_code_to_numeric_ranking(string):
    """Convert field codes (s, d, c, or z) to a numerical ranking that puts them in
    that order (rather than alphabetical).

    Raises ValueError for any other input. An explicit exception is used
    (rather than 'assert False') so that an unknown code can never silently
    yield None as a sort key when Python runs with -O.
    """
    ranking = "sdcz".find(string) if len(string) == 1 else -1
    if ranking == -1:
        raise ValueError(
            "Unknown field code {!r}: expected one of 's', 'd', 'c', or 'z'."
            .format(string))
    return ranking


def parse_argument_constness(file_name):
    """Read the given file to determine which arguments are const (i.e.,
    INTENT(IN)). Since FORTRAN-77 does not support this directly we examine the
    documentation and consider anything labeled as 'param[in]' as const.

    Only lines containing 'param' that appear after the first line containing
    'Arguments' and before the first line containing 'Authors' are examined.

    Returns the list of const argument names in the order in which they appear
    in the documentation.
    """
    const_parameters = []
    in_arguments_section = False
    with open(file_name, 'r') as handle:
        for line in handle:
            if "Arguments" in line:
                in_arguments_section = True
            elif "Authors" in line:
                break
            if not (in_arguments_section and "param" in line):
                continue

            # Skip the three leading characters (the comment marker, e.g.,
            # '*> ') as the parser always has.
            description = line[3:]
            # The LAPACK documentation might not be synchronized with the
            # implementation, but this is the best we can do
            if "[in]" in description:
                # The argument name is the last word on the line. Splitting on
                # arbitrary whitespace keeps this working when the line has
                # trailing blanks or is not terminated by a newline.
                const_parameters.append(description.split()[-1])
    return const_parameters


def parse_arguments(file_name):
    """Parse the input arguments to a subroutine by reading its documentation.

    The lines of file_name starting at the first one containing 'Definition'
    and ending just before the first one containing 'Purpose' are collected
    (with their first character removed and surrounding whitespace stripped)
    and then condensed into one line per declaration.

    Returns a list of strings. The first entry is the line declaring the
    routine (everything up to and including the first ')' joined onto one
    line); the remaining entries are the declaration lines that follow it,
    with all parenthesized text (e.g., array extents) removed.

    Raises AssertionError if 'Purpose' shows up before 'Definition' and
    IndexError if no line containing 'SUBROUTINE ' or 'FUNCTION' is found.
    """
    store_line = False
    lines = list()
    with open(file_name, 'r') as handle:
        for line in handle:
            if "Definition" in line:
                store_line = True
            elif "Purpose" in line:
                assert store_line
                break  # we are done

            if store_line:
                # remove the prepended '*'
                lines.append(line[1:].strip())

    # remove lines before 'SUBROUTINE' (or 'FUNCTION'): afterwards lines[0] is
    # the line on which the declaration of the routine starts
    line = lines[0]
    while ("SUBROUTINE " not in line and "FUNCTION" not in line):
        lines.pop(0)
        line = lines[0]

    # remove blanks and lines with dots
    lines = list(filter(lambda u: ".." not in u and u != "", lines))

    # get all subroutine arguments in one set of parentheses: keep appending
    # the following line (with any '$' continuation markers blanked out) until
    # the first line contains a closing parenthesis
    while ')' not in lines[0]:
        next_line = lines[1].replace('$', ' ')
        lines[0] += ' ' + next_line
        lines.pop(1)

    # join lines that start with '$'s or when the previous line ends in ','.
    # The index is only advanced when nothing was merged since popping an
    # entry moves the next line into the current position.
    line_n = 0
    while line_n < len(lines):
        if '$' in lines[line_n]:
            assert line_n > 0
            # drop the first character (presumably the '$' itself) before
            # appending to the previous line
            lines[line_n - 1] += ' ' + lines[line_n][1:].strip()
            lines.pop(line_n)
        elif line_n > 0 and lines[line_n - 1][-1] == ',':
            # NOTE(review): this also drops the first character of the
            # continuation line even though it contains no '$' -- presumably
            # some other continuation marker; confirm this cannot cut off the
            # first letter of an argument name.
            lines[line_n - 1] += ' ' + lines[line_n][1:].strip()
            lines.pop(line_n)
        else:
            line_n += 1

    # since everything is passed by pointer in Fortran, go ahead and delete
    # array descriptions
    def remove_text_in_parentheses(string):
        """Remove all text in parentheses and parentheses themselves.

        Works recursively: the text between the first '(' and the first ')'
        is cut out and the remainder is processed again until no '(' is left.
        """
        left_index = string.find('(')
        if left_index != -1:
            # NOTE(review): this is the first ')' anywhere in the string, so
            # nested parentheses are presumably not expected here -- confirm.
            right_index = string.find(')')
            assert right_index != -1
            return remove_text_in_parentheses(string[:left_index]
                                              + string[right_index + 1:])
        return string

    # the first line keeps its parentheses since they contain the argument list
    return [lines[0]] + list(map(remove_text_in_parentheses, lines[1:]))


class FortranRoutine:
    """Class storing information about a FORTRAN subroutine or function with methods
    to print it. Has the following members:

    self.file_name : original provided file name

    self.disable_floating_point_exceptions : whether or not to disable floating
    point exceptions for the duration of the routine call

    self.lines : Lightly edited (preceding '*'s removed, '$' lines joined, etc.)
    lines from the documentation of the given function that completely describe
    its function arguments and, if relevant, return type.

    self.const_arguments : List of function arguments that correspond to const
    pointers: this is derived from 'intent[in]' statements in the FORTRAN
    function documentation.

    self.return_type : the return type of the subroutine (always void) or
    function (never void).

    self.field_code : One-letter code ('s', 'd', 'c', or 'z') describing the
    type of floating point inputs the function expects.

    self.general_name : string containing the general name (i.e., name without
    field code) of the subroutine or function.

    self.macro_name : string containing a macro that will expand to the correct
    symbol name.

    self.arguments_by_type : A dictionary (with regular expression keys) mapping
    regular expressions matching FORTRAN types to input arguments.
    """
    def __init__(self, file_name, disable_floating_point_exceptions=False):
        self.file_name = file_name
        self.disable_floating_point_exceptions = disable_floating_point_exceptions
        self.lines = parse_arguments(file_name)
        self.const_arguments = parse_argument_constness(file_name)

        assert "SUBROUTINE" in self.lines[0] or "FUNCTION" in self.lines[0]
        if "SUBROUTINE" in self.lines[0]:
            self.return_type = "void"
            name = self.lines[0].split()[1]
        else:
            function_start = self.lines[0].find("FUNCTION")
            self.return_type = fortran_to_cpp_type(self.lines[0][:function_start].strip())
            words = self.lines[0].split(' ')
            while words[0] != "FUNCTION":
                words.pop(0)
            name = words[1]
        parentheses_location = name.find('(')
        if parentheses_location != -1:
            name = name[:parentheses_location]

        self.field_code = name[0].lower()
        self.general_name = name[1:].strip().lower()

        # we need a different macro for mangling FORTRAN names with underscores:
        fortran_name = self.field_code + self.general_name
        if "_" in fortran_name:
            self.macro_name = "DEAL_II_FORTRAN_MANGLE_UNDERSCORE({}, {})".format(
                fortran_name, fortran_name.upper())
        else:
            self.macro_name = "DEAL_II_FORTRAN_MANGLE({}, {})".format(
                fortran_name, fortran_name.upper())

        self.arguments_by_type = dict()
        self.arguments_by_type[re.compile("CHARACTER")] = list()
        self.arguments_by_type[re.compile("LOGICAL")] = list()
        self.arguments_by_type[re.compile("INTEGER")] = list()
        self.arguments_by_type[re.compile("REAL")] = list()
        self.arguments_by_type[re.compile("DOUBLE (PRECISION)?")] = list()
        self.arguments_by_type[re.compile("COMPLEX[^*]")] = list()
        self.arguments_by_type[re.compile("COMPLEX\\*16")] = list()
        # TODO one day we should support EXTERNAL

        for line in self.lines:
            found_type = False
            if line.split() == ["IMPLICIT", "NONE"]:
                continue
            if "SUBROUTINE" in line or "FUNCTION" in line:
                start = line.find('(') + 1
                end = line.find(')')
                self.arguments = [u.strip() for u in line[start:end].strip().split(',')]
                found_type = True
            else:
                for key in self.arguments_by_type:
                    match = key.match(line)
                    if match is not None:
                        self.arguments_by_type[key] += [u.strip() for u in
                                                        line[match.end():].split(',')]
                        found_type = True
            if not found_type:
                if "EXTERNAL" in line:
                    raise NotImplementedError(
                        "The parser cannot handle EXTERNAL declarations since "
                        "FORTRAN does not provide type information in that case.")

            assert found_type

    def is_subroutine(self):
        """Return whether or not the routine is a SUBROUTINE. It might also be a
        FUNCTION, which is not yet implemented.
        """
        for line in self.lines:
            if "SUBROUTINE" in line:
                return True
        return False

    def print_prototype(self, file=sys.stdout):
        """Print a deal.II-compatible function prototype.
        """
        print("", file=file)

        print(self.return_type, self.macro_name, "(",
              sep=os.linesep, end='', file=file)
        for index, argument in enumerate(self.arguments):
            found_type = False
            if argument in self.const_arguments:
                print("const ", end='', file=file)
            for (key, arguments) in self.arguments_by_type.items():
                if argument in arguments:
                    print(fortran_to_cpp_type(key), end='', file=file)
                    found_type = True

            if not found_type:
                import pdb
                pdb.set_trace()
            print(" *" + argument.lower(), end='', file=file)

            if index < len(self.arguments) - 1:
                print(", ", end='', file=file)
        print(");", file=file)

    def print_base_case_template(self, file=sys.stdout):
        """Print a deal.II-compatible template base case. Each floating point type is
        converting to a template parameter. This template will always throw an
        exception, but is useful as a base case.

        If there are no floating point arguments to the function (e.g., dlamch)
        then nothing is printed."""
        excluded_keys = [re.compile("CHARACTER"), re.compile("LOGICAL"),
                         re.compile("INTEGER")]
        template_n = 1
        argument_to_template_type = dict()

        # assign new template types to floating point arguments:
        for argument in self.arguments:
            found_type = False
            for (key, arguments) in self.arguments_by_type.items():
                if argument in arguments:
                    if key in excluded_keys:
                        argument_to_template_type[argument] = fortran_to_cpp_type(key)
                    else:
                        argument_to_template_type[argument] = "number" + str(template_n)
                        template_n += 1
                    found_type = True
            assert found_type

        # we number templates from 1, so if we are still at 1 there are no
        # template arguments. We add one that will also be the return type. To
        # the best of my knowledge this only happens with DLAMCH and SLAMCH.
        if template_n == 1:
            n_template_types = 1
            assert self.return_type != "void"
        else:
            n_template_types = template_n - 1

        print("\n\n", file=file)
        # print template signature
        print("template <", end='', file=file)
        for template_n in range(1, n_template_types + 1):
            if template_n < n_template_types:
                print("typename number" + str(template_n) + ", ", end='', file=file)
            else:
                print("typename number" + str(template_n) + ">", file=file)
        print("inline ", file=file)
        if self.return_type == "void":
            print(self.return_type, file=file)
        # since this function will just throw an exception whenever its called
        # the return type is not important. This also covers the DLAMCH-like
        # case.
        else:
            print("number1", file=file)

        print(self.general_name, file=file)
        print("(", end='', file=file)

        # print types but not argument names
        for index, argument in enumerate(self.arguments):
            if argument in self.const_arguments:
                print("const ", end='', file=file)
            print(argument_to_template_type[argument], end='', file=file)

            print(" *", end='', file=file)

            if index < len(self.arguments) - 1:
                print(", ", end='', file=file)

        print(")", file=file)
        print("{", file=file)
        print("  Assert(false, ExcNotImplemented());", file=file)
        if self.return_type == "void":
            pass
        else:
            print("  return number1();", file=file)
        print("}", file=file)

    def print_overload_definition(self, file=sys.stdout):
        """Print a definition of an overload for the present subroutine.
        """
        excluded_keys = [re.compile("CHARACTER"), re.compile("LOGICAL"),
                         re.compile("INTEGER")]
        # We only specialize the template if there are no function arguments
        # with which we could deduce the correct function
        need_template = True
        for argument in self.arguments:
            for (key, arguments) in self.arguments_by_type.items():
                if argument in arguments and key not in excluded_keys:
                    need_template = False
                    break

        print("\n\n", file=file)
        if need_template:
            print("template <>", file=file)
        print("inline {}".format(self.return_type), file=file)
        print(self.general_name, file=file)
        print("(", end='', file=file)

        # print arguments
        for index, argument in enumerate(self.arguments):
            found_type = False
            if argument in self.const_arguments:
                print("const ", end='', file=file)
            for (key, arguments) in self.arguments_by_type.items():
                if argument in arguments:
                    print(fortran_to_cpp_type(key), end='', file=file)
                    found_type = True

            assert found_type
            print(" *" + argument.lower(), end='', file=file)

            if index < len(self.arguments) - 1:
                print(", ", end='', file=file)
        print(")", file=file)

        print("{", file=file)
        print("#ifdef DEAL_II_WITH_LAPACK", file=file)
        if self.disable_floating_point_exceptions:
            print("  // Netlib and Atlas Lapack perform floating point ", end='', file=file)
            print("tests (e.g. divide-by-zero)", file=file)
            print("  // within calls to some functions, which cause ", end='', file=file)
            print("floating point exceptions to be thrown", file=file)
            print("  // (at least in debug mode). Therefore, we wrap the ", end='', file=file)
            print("calls into the following ", file=file)
            print("  // code to suppress the exception.", file=file)
            print("#ifdef DEAL_II_HAVE_FP_EXCEPTIONS", file=file)
            print("fenv_t fp_exceptions;", file=file)
            print("feholdexcept(&fp_exceptions);", file=file)
            print("#endif // DEAL_II_HAVE_FP_EXCEPTIONS", file=file)
        # call the Fortran function
        if self.return_type == "void":
            print("  ", end='', file=file)
        else:
            print("  return ", end='', file=file)
        print(self.macro_name + "(", end='', file=file)
        for index, argument in enumerate(self.arguments):
            print(argument.lower(), end='', file=file)

            if index < len(self.arguments) - 1:
                print(", ", end='', file=file)
        print(");", file=file)

        if self.disable_floating_point_exceptions:
            print("#ifdef DEAL_II_HAVE_FP_EXCEPTIONS", file=file)
            print("fesetenv(&fp_exceptions);", file=file)
            print("#endif // DEAL_II_HAVE_FP_EXCEPTIONS", file=file)

        print("#else", file=file)
        for argument in self.arguments:
            print("  (void){};".format(argument.lower()), file=file)
        print("  Assert(false, LAPACKSupport::ExcMissing(\"{}\"));"
              .format(self.field_code + self.general_name), file=file)
        if self.return_type != "void":
            print("  return {}();".format(self.return_type), file=file)
        print("#endif // DEAL_II_WITH_LAPACK", file=file)
        print("}", file=file)

    def __repr__(self):
        """Representation of the object.
        """
        return self.field_code + self.general_name


def _read_routine_names(file_name):
    """Return the whitespace-stripped, non-blank lines of the given file.

    Blank lines are skipped since an empty routine name would turn the glob
    pattern used to find the routine's source file into '*.f'.
    """
    with open(file_name) as handle:
        stripped_lines = (line.strip() for line in handle)
        return [name for name in stripped_lines if name]


def main():
    """Determine the subroutines and then print needed declarations and definitions
    to the header.

    Raises ValueError if the LAPACK root directory was not provided or if none
    of the requested routines could be found in it.
    """
    parser = argparse.ArgumentParser(
        description=("Automatically generate deal.II-compatible template BLAS "
                     + "and LAPACK wrappers by parsing the documentation provided "
                     + " by each library. The best way to get this documentation "
                     + "is to clone the LAPACK development sources available at "
                     + "https://github.com/Reference-LAPACK and then pass the "
                     + "root directory location to this program with the "
                     + "--lapack-root-directory option."))
    parser.add_argument("--lapack-root-directory", type=str,
                        help="Path of the LAPACK source root directory.")
    parser.add_argument("--lapack-subroutines", type=str,
                        help="file listing which LAPACK subroutines we should "
                        "wrap. Defaults to 'lapack_templates_data/lapack_subroutines' "
                        "in the current directory.")
    parser.add_argument("--blas-subroutines", type=str,
                        help="file listing which BLAS subroutines we should "
                        "wrap. Defaults to 'lapack_templates_data/blas_subroutines' "
                        "in the current directory.")
    parser.add_argument("--fp-guarded-subroutines", type=str,
                        help="Subroutines which should be wrapped with "
                        "disabling and enabling of floating point exceptions. "
                        "Defaults to 'lapack_templates_data/fp_subroutines' in "
                        "the current directory.")
    parser.add_argument("--output-file", type=str,
                        help="File to which we write. Defaults to "
                        "'lapack_templates.h' in the current directory.")
    args = vars(parser.parse_args())

    if args["lapack_root_directory"] is None:
        raise ValueError("The path to the default LAPACK directory must be "
                         "specified.")

    defaults = {
        "blas_subroutines": "lapack_templates_data/blas_subroutines",
        "lapack_subroutines": "lapack_templates_data/lapack_subroutines",
        "fp_guarded_subroutines": "lapack_templates_data/fp_subroutines",
        "output_file": "lapack_templates.h",
    }
    for option, default_value in defaults.items():
        if args[option] is None:
            args[option] = default_value

    lapack_subroutines = _read_routine_names(args["lapack_subroutines"])
    blas_subroutines = _read_routine_names(args["blas_subroutines"])
    fp_guarded_subroutines = set(
        _read_routine_names(args["fp_guarded_subroutines"]))

    #
    # Parse subroutines:
    #
    root_directory = args["lapack_root_directory"]
    all_general_routines = set(blas_subroutines + lapack_subroutines)
    search_locations = [
        (lapack_subroutines, root_directory + "/SRC/"),
        # dlamch and slamch are in the INSTALL directory
        (lapack_subroutines, root_directory + "/INSTALL/"),
        (blas_subroutines, root_directory + "/BLAS/SRC/"),
    ]
    routines = list()
    for general_routines, routine_directory in search_locations:
        for general_routine in general_routines:
            pattern = routine_directory + "*" + general_routine + ".f"
            for file_name in glob.glob(pattern):
                # strip the field code and the '.f' extension
                routine_name = osp.basename(file_name)[1:-2]
                guard_fp = routine_name in fp_guarded_subroutines
                try:
                    routines.append(FortranRoutine(file_name, guard_fp))
                except NotImplementedError:
                    # Not a lot we can do. If its in the allowed list we will fail
                    # again later anyway
                    continue
    if not routines:
        raise ValueError(
            "None of the requested routines could be found below '{}'."
            .format(root_directory))
    # the sort is stable so routines with the same general name stay in the
    # order in which they were found
    routines.sort(key=lambda u: u.general_name)

    #
    # Collect subroutines together that have the same general name
    #
    organized_routines = list()
    for routine in routines:
        if (organized_routines
                and organized_routines[-1][0].general_name == routine.general_name):
            organized_routines[-1].append(routine)
        else:
            organized_routines.append([routine])
    #
    # sort by type (s, d, c z)
    #
    for routine_list in organized_routines:
        routine_list.sort(key=lambda a: field_code_to_numeric_ranking(a.field_code))

    # only routines which were explicitly requested end up in the header (the
    # glob patterns above may also match other routines)
    requested_routines = [routine_list for routine_list in organized_routines
                          if routine_list[0].general_name in all_general_routines]

    #
    # Generate output
    #
    year = time.strftime("%Y")
    preamble = (
        "// ------------------------------------------------------------------------",
        "//",
        "// SPDX-License-Identifier: LGPL-2.1-or-later",
        "// Copyright (C) 2005 - {} by the deal.II authors".format(year),
        "//",
        "// This file is part of the deal.II library.",
        "//",
        "// Part of the source code is dual licensed under Apache-2.0 WITH",
        "// LLVM-exception OR LGPL-2.1-or-later. Detailed license information",
        "// governing the source code and code contributions can be found in",
        "// LICENSE.md and CONTRIBUTING.md at the top level directory of deal.II.",
        "//",
        "// ------------------------------------------------------------------------------",
        "",
        "",
        "#ifndef dealii_lapack_templates_h",
        "#define dealii_lapack_templates_h",
        "",
        "#include <deal.II/base/config.h>",
        "",
        "#include <deal.II/lac/lapack_support.h>",
        "",
        "#ifdef DEAL_II_HAVE_FP_EXCEPTIONS",
        "#  include <cfenv>",
        "#endif",
        "",
        "extern \"C\"",
        "{",
    )
    with open(args["output_file"], 'w') as output:
        for line in preamble:
            print(line, file=output)
        for routine_list in requested_routines:
            for routine in routine_list:
                routine.print_prototype(file=output)
        print("}", file=output)

        print("", file=output)
        print("DEAL_II_NAMESPACE_OPEN", file=output)
        print("", file=output)
        for routine_list in requested_routines:
            routine_list[0].print_base_case_template(file=output)
            for routine in routine_list:
                routine.print_overload_definition(file=output)

        print("DEAL_II_NAMESPACE_CLOSE", file=output)
        print("", file=output)
        print("#endif", file=output)


if __name__ == '__main__':
    main()
