#!/usr/bin/env python
# pngchunk
# Chunk editing/extraction tool.

from __future__ import print_function

import argparse
import collections

# https://docs.python.org/2.7/library/io.html
import io
import string
import struct
import sys
import warnings
import zlib

# Local module.
import png


# A chunk to be written to the output stream.
# `type` is the 4-byte chunk type (as bytes, for example b"gAMA");
# `content` is a file-like object whose .read() supplies the chunk data.
Chunk = collections.namedtuple("Chunk", "type content")


def process(out, args):
    """Process the PNG file args.input to the output, chunk by chunk.
    Chunks can be inserted, removed, replaced, or sometimes edited.
    Chunks are specified by their 4 byte Chunk Type;
    see https://www.w3.org/TR/2003/REC-PNG-20031110/#5Chunk-layout .
    The chunks in args.delete will be removed from the stream.
    The chunks in args.chunk will be inserted into the stream
    with their contents taken from the named files.

    Other options on the args object will create particular
    ancillary chunks.

    .gamma -> gAMA chunk
    .sigbit -> sBIT chunk
    .iccprofile -> iCCP chunk
    .transparent -> tRNS chunk
    .background -> bKGD chunk

    Chunk types need not be official PNG chunks at all.
    Non-standard chunks can be created.

    `out` is the binary stream that the edited PNG is written to.
    The args.chunk list itself is not modified.
    Returns whatever png.write_chunks returns.
    """

    # Work on a private copy, so that the caller's args.chunk list is
    # left exactly as it was supplied.
    chunks = list(args.chunk)

    # Convert options to chunks. Each one goes to the front of the
    # list, so chunks made from options precede the --chunk entries
    # (and a later --chunk entry of the same type wins a replacement).
    if args.gamma:
        # Gamma is stored scaled by 100000, as a 4-byte unsigned int.
        scaled = int(round(1e5 * args.gamma))
        chunks.insert(0, Chunk(b"gAMA", io.BytesIO(struct.pack(">I", scaled))))
    if args.sigbit:
        # One byte per significant-bits value.
        payload = struct.pack("%dB" % len(args.sigbit), *args.sigbit)
        chunks.insert(0, Chunk(b"sBIT", io.BytesIO(payload)))
    if args.iccprofile:
        # http://www.w3.org/TR/PNG/#11iCCP
        # Fixed profile name, then two zero bytes, then the
        # zlib-compressed profile data read from the file.
        payload = b"a color profile\x00\x00" + zlib.compress(
            args.iccprofile.read()
        )
        chunks.insert(0, Chunk(b"iCCP", io.BytesIO(payload)))
    if args.transparent:
        # https://www.w3.org/TR/2003/REC-PNG-20031110/#11tRNS
        # Each sample is written as a big-endian 16-bit value.
        payload = struct.pack(
            ">%dH" % len(args.transparent), *args.transparent
        )
        chunks.insert(0, Chunk(b"tRNS", io.BytesIO(payload)))
    if args.background:
        # https://www.w3.org/TR/2003/REC-PNG-20031110/#11bKGD
        # Each sample is written as a big-endian 16-bit value.
        payload = struct.pack(
            ">%dH" % len(args.background), *args.background
        )
        chunks.insert(0, Chunk(b"bKGD", io.BytesIO(payload)))

    # Create:
    # - a set of chunks to delete
    # - a dict of chunks to replace
    # - a list of chunks to add

    delete = set(args.delete)
    # Generally, there should be at most one of the 'replacing' chunks.
    replacing = {b"gAMA", b"sBIT", b"PLTE", b"tRNS", b"sPLT", b"IHDR"}
    replace = dict()
    add = []

    for chunk in chunks:
        if chunk.type in replacing:
            replace[chunk.type] = chunk
        else:
            add.append(chunk)

    reader = png.Reader(file=args.input)

    return png.write_chunks(
        out, edit_chunks(reader.chunks(), delete, replace, add)
    )


def edit_chunks(chunks, delete, replace, add):
    """
    Iterate over chunks, yielding edited (type, content) pairs.

    `chunks` is an iterable of (type, content) pairs;
    `delete` is a collection of chunk types to drop from the stream;
    `replace` maps a chunk type to the Chunk that takes the place of
    the first chunk of that type found in the stream;
    `add` is a sequence of Chunk instances to insert.

    Replacement chunks whose type has not appeared by then, followed
    by all the chunks in `add`, are emitted immediately before the
    first IDAT chunk that is passed through.

    Subtle: the new chunks have to have their contents .read().

    The `replace` and `add` arguments are not modified.
    """
    # Private copies: entries are consumed as they are used, and the
    # caller's containers must not change underneath them.
    pending = dict(replace)
    extra = list(add)

    for chunk_type, content in chunks:
        if chunk_type in delete:
            continue
        if chunk_type in pending:
            # Only the first chunk of this type is replaced; any later
            # chunk of the same type passes through unchanged.
            yield chunk_type, pending.pop(chunk_type).content.read()
            continue

        if chunk_type == b"IDAT":
            # Anything still waiting to replace a chunk that never
            # appeared is emitted now, ahead of the image data.
            for chunk in pending.values():
                yield chunk.type, chunk.content.read()
            pending = {}

            # We reached IDAT; add all remaining chunks now.
            for chunk in extra:
                yield chunk.type, chunk.content.read()
            extra = []

        yield chunk_type, content


def chunk_name(s):
    """
    Validate a chunk name given as an option value, and
    return it encoded as ASCII bytes.

    Raises ValueError unless s is exactly 4 ASCII letters.
    """

    # See https://www.w3.org/TR/2003/REC-PNG-20031110/#table51
    letters = frozenset(string.ascii_letters)
    if len(s) != 4 or not letters.issuperset(s):
        raise ValueError("Chunk name must be 4 ASCII letters")
    return s.encode("ascii")


def comma_list(s):
    """
    Convert s, a comma separated list of integers,
    into a tuple of int.

    Raises ValueError if any field is not a valid integer.
    """

    numbers = [int(field) for field in s.split(",")]
    return tuple(numbers)


def hex_color(s):
    """
    Validate a hex colour string and convert it to a tuple of int.

    A leading "#" is optional. 1, 2 or 4 hex digits give a single
    grey value; 3, 6 or 12 hex digits give an (R, G, B) triple.
    Raises ValueError for any other length or for non-hex characters.
    """

    digits = s[1:] if s.startswith("#") else s

    bad_length = len(digits) not in (1, 2, 3, 4, 6, 12)
    if bad_length or not set(digits).issubset(string.hexdigits):
        raise ValueError("colour must be 1,2,3,4,6, or 12 hex-digits")

    # 4-bit RGB is widened to 8-bit by doubling each digit.
    if len(digits) == 3:
        digits = "".join(d * 2 for d in digits)

    if len(digits) in (6, 12):
        # Three equal-width channels.
        width = len(digits) // 3
        channels = []
        for start in range(0, len(digits), width):
            channels.append(int(digits[start : start + width], 16))
        return tuple(channels)

    # Remaining valid lengths (1, 2, 4) are a single grey value.
    return (int(digits, 16),)


def binary_stdout():
    """
    Return the stream to use for writing binary output:
    sys.stdout.buffer when that attribute exists (probably Python 3),
    otherwise sys.stdout itself.
    """
    return getattr(sys.stdout, "buffer", sys.stdout)


def main(argv=None):
    """
    Command-line entry point.

    `argv` is the full argument vector including the program name in
    position 0 (which is skipped); it defaults to sys.argv.
    Parses the options, opens the files named by --chunk, and runs
    process() writing to binary standard output.
    Returns the result of process().

    Invalid option values, including an invalid chunk name given to
    --chunk, are reported through the argument parser's error
    mechanism.
    """
    if argv is None:
        argv = sys.argv

    argv = argv[1:]

    parser = argparse.ArgumentParser()
    parser.add_argument("--gamma", type=float)
    parser.add_argument("--sigbit", type=comma_list, metavar="D[,D[,D[,D]]]")
    parser.add_argument("--iccprofile", type=argparse.FileType("rb"))
    parser.add_argument("--transparent", type=hex_color, metavar="#RRGGBB")
    parser.add_argument(
        "--background",
        type=hex_color,
        metavar="#RRGGBB",
        help="background colour for bKGD chunk",
    )
    parser.add_argument(
        "--delete",
        action="append",
        default=[],
        type=chunk_name,
        help="delete the chunk",
    )
    parser.add_argument(
        "--chunk",
        action="append",
        nargs=2,
        default=[],
        type=str,
        help="insert chunk, taking contents from file",
    )
    parser.add_argument("input", type=argparse.FileType("rb"))

    args = parser.parse_args(argv)

    # Fix for binary stdin on Python 3
    try:
        args.input = args.input.buffer
    except AttributeError:
        pass

    # Reprocess the chunk arguments, converting each pair into a Chunk.
    # The two values of --chunk cannot be given different argparse
    # types, so the chunk name is validated here; a bad name is
    # reported the same way as a bad --delete name.
    chunks = []
    for name, path in args.chunk:
        try:
            chunk_type = chunk_name(name)
        except ValueError as err:
            parser.error("argument --chunk: %s" % err)
        # The file is left open; its content is read when the chunk
        # is written out.
        chunks.append(Chunk(chunk_type, open(path, "rb")))
    args.chunk = chunks

    return process(binary_stdout(), args)


# Run the command-line tool when this file is executed as a script.
if __name__ == "__main__":
    main()
