#!/usr/bin/env python

# pripamtopng
#
# Python Raster Image PAM to PNG

from __future__ import print_function

import struct
import sys

from array import array

import png


def read_pam_header(infile):
    """
    Read (the rest of a) PAM header.
    `infile` should be positioned immediately after the initial 'P7' line
    (at the beginning of the second line).
    Returns are as for `read_pnm_header`.

    Comment lines (first non-blank character is '#') and
    blank lines are skipped.
    When a keyword appears on more than one line its values are
    joined with a single space.

    Raises `EOFError` if the input ends before the ENDHDR line;
    raises `png.Error` if any of WIDTH, HEIGHT, DEPTH, MAXVAL is
    missing or is not a positive integer.
    """

    # Unlike PBM, PGM, and PPM, we can read the header a line at a time.
    header = dict()
    while True:
        raw = infile.readline()
        if not raw:
            # readline() only returns an empty result at end of file;
            # a blank line still carries its newline.
            raise EOFError('PAM ended prematurely')
        line = raw.strip()
        if line == b'ENDHDR':
            break
        # A header line with zero tokens is valid and has no meaning.
        # Slice instead of indexing: on Python 3 indexing bytes gives
        # an int, which never compares equal to b'#'.
        if not line or line[:1] == b'#':
            continue
        fields = line.split(None, 1)
        key = fields[0]
        # A keyword with no value is stored as empty; for the required
        # keywords that is then reported by the integer check below.
        value = fields[1] if len(fields) > 1 else b''
        if key not in header:
            header[key] = value
        else:
            header[key] += b' ' + value

    required = [b'WIDTH', b'HEIGHT', b'DEPTH', b'MAXVAL']
    required_str = b', '.join(required).decode('ascii')
    result = []
    for token in required:
        if token not in header:
            raise png.Error('PAM file must specify ' + required_str)
        try:
            x = int(header[token])
        except ValueError:
            raise png.Error(required_str + ' must all be valid integers')
        if x <= 0:
            raise png.Error(required_str + ' must all be positive integers')
        result.append(x)

    return ('P7',) + tuple(result)


def read_pnm_header(infile):
    """
    Read a PNM header, returning (format,width,height,depth,maxval).
    Also reads a PAM header (by using a helper function).
    `width` and `height` are in pixels.
    `depth` is the number of channels in the image;
    for PBM and PGM it is synthesized as 1, for PPM as 3;
    for PAM images it is read from the header.
    `maxval` is synthesized (as 1) for PBM images.

    On return `infile` is positioned just after the single whitespace
    byte that ends the header.

    Raises `NotImplementedError` when the magic number is not one of
    P5, P6, P7; raises `png.Error` when the header is truncated or
    contains an unexpected byte.
    """

    # Generally, see http://netpbm.sourceforge.net/doc/ppm.html
    # and http://netpbm.sourceforge.net/doc/pam.html

    supported = (b'P5', b'P6', b'P7')

    # Technically 'P7' must be followed by a newline, so by using
    # rstrip() we are being liberal in what we accept.  I think this
    # is acceptable.
    magic = infile.read(3).rstrip()
    if magic not in supported:
        raise NotImplementedError('file format %s not supported' % magic)
    if magic == b'P7':
        # PAM header parsing is completely different.
        return read_pam_header(infile)

    # Expected number of tokens in header (3 for P4, 4 for P6).
    # The PBM magic numbers are not in `supported` above, so the
    # PBM branches below are never taken at present.
    expected = 4
    pbm = (b'P1', b'P4')
    if magic in pbm:
        expected = 3
    header = [magic]

    # We must read the rest of the header byte by byte because
    # the final whitespace character may not be a newline.
    # Of course all PNM files in the wild use a newline at this point,
    # but we are strong and so we avoid
    # the temptation to use readline.
    def getc():
        c = infile.read(1)
        if not c:
            raise png.Error('premature EOF reading PNM header')
        return c

    c = getc()
    while True:
        # Skip any mixture of whitespace and comments that precedes a
        # token.  Handling both in one loop means a comment line is
        # accepted wherever it appears between tokens, including on the
        # line after a previous token.
        while c == b'#' or c.isspace():
            if c == b'#':
                # A comment runs to the end of its line.
                while c not in b'\n\r':
                    c = getc()
            c = getc()
        if not c.isdigit():
            raise png.Error('unexpected byte %r found in header' % c)
        # According to the specification it is legal to have comments
        # that appear in the middle of a token.
        # I've never seen it; and,
        # it's a bit awkward to code good lexers in Python (no goto).
        # So we break on such cases.
        token = b''
        while c.isdigit():
            token += c
            c = getc()
        # All "tokens" are decimal integers, so convert them here.
        header.append(int(token))
        if len(header) == expected:
            break
    # `c` is the byte right after the last token; it must be the single
    # whitespace byte that separates the header from the pixel data.
    if not c.isspace():
        raise png.Error('expected header to end with whitespace, not %s' % c)

    if magic in pbm:
        # synthesize a MAXVAL
        header.append(1)
    depth = (1, 3)[magic == b'P6']
    return header[0], header[1], header[2], depth, header[3]


def convert_pnm(w, infile, outfile):
    """
    Write the raw pixel data of a PNM file to `outfile` as a PNG,
    using the image parameters held by the writer object `w`
    (its `width`, `height`, `planes` and `bitdepth`).
    `infile` should be positioned at the first pixel.
    Works for (binary) PGM, PPM, and PAM formats.
    """

    # Rows are produced lazily, one per image row.
    pixel_rows = scan_rows_from_file(
        infile,
        width=w.width,
        height=w.height,
        planes=w.planes,
        bitdepth=w.bitdepth)
    w.write(outfile, pixel_rows)

def scan_rows_from_file(infile, width, height, planes, bitdepth):
    """
    Generate a sequence of rows from the input file `infile`.
    The input file should be in a "Netpbm-like" binary format.
    The input file should be positioned at the beginning of the first pixel.
    The number of pixels to read is taken from
    the image dimensions (`width`, `height`, `planes`);
    the number of bytes per value is implied by `bitdepth`:
    one byte when `bitdepth` is 8 or less,
    otherwise two bytes, most significant byte first.
    Each row is yielded as a single sequence of values
    (an `array` of type 'B' or 'H').

    Raises `png.Error` if the file ends before
    `height` complete rows have been read.
    """

    # Values per row
    vpr = width * planes
    if bitdepth > 8:
        # Only the full 16-bit case is accepted for multi-byte samples.
        assert bitdepth == 16
        row_bytes = vpr * 2
        fmt = '>%dH' % vpr

        def decode(data):
            return array('H', struct.unpack(fmt, data))
    else:
        row_bytes = vpr

        def decode(data):
            return array('B', data)

    for _ in range(height):
        data = infile.read(row_bytes)
        # A short read means truncated pixel data; report it here
        # rather than yielding a row of the wrong length.
        if len(data) != row_bytes:
            raise png.Error('premature EOF reading pixel data')
        yield decode(data)


def add_options(parser):
    """
    Register this tool's command line options on *parser*
    (one *parser.add_argument* call per option)
    and return the same parser.
    """
    # Each entry is (names, keyword settings) for one add_argument call.
    option_specs = (
        (("-c", "--compression"),
         dict(type=int,
              metavar="level",
              help="zlib compression level (0-9)")),
        (("file",),
         dict(help="input PAM/PNM file to convert")),
    )
    for names, settings in option_specs:
        parser.add_argument(*names, **settings)
    return parser


def main(argv=None):
    """
    Command line entry point: read the PNM/PAM file named on the
    command line ("-" for standard input) and write it as a PNG
    to standard output.

    `argv` is the full argument list including the program name;
    it defaults to `sys.argv`.

    Raises `NotImplementedError` when the input maxval is not of
    the form 2**n - 1 for n from 1 to 16, and `png.Error` when the
    number of channels is not between 1 and 4.
    """
    if argv is None:
        argv = sys.argv

    # Parse command line arguments
    from argparse import ArgumentParser
    parser = ArgumentParser()
    version = '%(prog)s ' + png.__version__
    parser.add_argument("--version", action='version', version=version)
    add_options(parser)

    args = parser.parse_args(args=argv[1:])

    # Prepare input and output files
    infile = cli_open(args.file)

    # Call after parsing, so that --version and --help work.
    outfile = binary_stdout()

    # Encode PNM to PNG
    _, width, height, depth, maxval = read_pnm_header(infile)

    # The NetPBM depth (number of channels) completely
    # determines the PNG format.
    # Observe:
    # - L, LA, RGB, RGBA are the 4 modes supported by PNG;
    # - they correspond to 1, 2, 3, 4 channels respectively.
    # We use the number of channels in the source image to
    # determine which one we have.
    # We ignore the NetPBM image type and the PAM TUPLTYPE.
    if depth not in (1, 2, 3, 4):
        # A PAM may declare any positive DEPTH.  Without this check a
        # depth above 4 would be written as 3-channel RGB and the
        # pixel data would be misread.
        raise png.Error(
            'input has %d channels; PNG supports only 1 to 4' % depth)
    greyscale = depth <= 2
    pamalpha = depth in (2, 4)
    supported = [2 ** x - 1 for x in range(1, 17)]
    try:
        mi = supported.index(maxval)
    except ValueError:
        raise NotImplementedError(
            'input maxval (%s) not in supported list %s' %
            (maxval, str(supported)))
    # supported[i] == 2**(i+1) - 1, so the index gives the bit depth.
    bitdepth = mi + 1
    writer = png.Writer(
        width, height,
        greyscale=greyscale,
        bitdepth=bitdepth,
        alpha=pamalpha,
        compression=args.compression)
    convert_pnm(writer, infile, outfile)


def cli_open(path):
    """
    Open the input named by `path` for reading bytes.
    The special name "-" selects standard input;
    anything else is opened as a file in binary mode.
    """
    return binary_stdin() if path == "-" else open(path, 'rb')


def binary_stdin():
    """
    A sys.stdin that returns bytes.

    The global `sys.stdin` itself is left unchanged (as
    `binary_stdout` does for `sys.stdout`), so that any later
    text-mode use of `sys.stdin` keeps working.
    """

    try:
        stdin = sys.stdin.buffer
    except AttributeError:
        # Probably Python 2, where bytes are strings.
        stdin = sys.stdin
    return stdin


def binary_stdout():
    """
    A sys.stdout that accepts bytes.
    """

    # Python 3 exposes the bytes interface as the `buffer` attribute;
    # when it is absent (probably Python 2, where bytes are strings)
    # sys.stdout is used directly.
    stream = getattr(sys.stdout, 'buffer', sys.stdout)

    # On Windows the C runtime file orientation needs changing.
    if sys.platform == "win32":
        import msvcrt
        import os
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    return stream


if __name__ == '__main__':
    # Run the converter; a png.Error is reported on stderr as a plain
    # message and turned into exit status 99.
    try:
        status = main()
    except png.Error as err:
        print(err, file=sys.stderr)
        status = 99
    sys.exit(status)
