import asyncio
import io
import json
import os
import sys
from typing import IO

import click
from PIL import Image

from ..bg import remove
from ..session_factory import new_session
from ..sessions import sessions_names


@click.command(
    name="b",
    help="for a byte stream as input",
)
@click.option(
    "-m",
    "--model",
    default="u2net",
    type=click.Choice(sessions_names),
    show_default=True,
    show_choices=True,
    help="model name",
)
@click.option(
    "-a",
    "--alpha-matting",
    is_flag=True,
    show_default=True,
    help="use alpha matting",
)
@click.option(
    "-af",
    "--alpha-matting-foreground-threshold",
    default=240,
    type=int,
    show_default=True,
    help="trimap fg threshold",
)
@click.option(
    "-ab",
    "--alpha-matting-background-threshold",
    default=10,
    type=int,
    show_default=True,
    help="trimap bg threshold",
)
@click.option(
    "-ae",
    "--alpha-matting-erode-size",
    default=10,
    type=int,
    show_default=True,
    help="erode size",
)
@click.option(
    "-om",
    "--only-mask",
    is_flag=True,
    show_default=True,
    help="output only the mask",
)
@click.option(
    "-ppm",
    "--post-process-mask",
    is_flag=True,
    show_default=True,
    help="post process the mask",
)
@click.option(
    "-bgc",
    "--bgcolor",
    default=None,
    type=(int, int, int, int),
    nargs=4,
    help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.option(
    "-o",
    "--output_specifier",
    type=str,
    help="printf-style specifier for output filenames (e.g. 'output-%d.png')",
)
@click.argument(
    "image_width",
    type=int,
)
@click.argument(
    "image_height",
    type=int,
)
def rs_command(
    model: str,
    extras: str,
    image_width: int,
    image_height: int,
    output_specifier: str,
    **kwargs
) -> None:
    """Remove the background from raw RGB frames streamed on stdin.

    Frames are read back to back from stdin, each exactly
    ``image_width * image_height * 3`` bytes, and decoded as PIL "RGB"
    images. Each frame is passed to ``remove``. The result is either saved
    as PNG to ``output_specifier % frame_index`` or, when no specifier is
    given, written to stdout as PNG bytes. Processing stops at end of
    input. A trailing partial frame is discarded, with a note on stderr.

    Args:
        model: Session name passed to ``new_session``.
        extras: Optional JSON object (may be None) whose entries are merged
            into the keyword arguments forwarded to ``remove``. Invalid JSON
            or non-object JSON is reported on stderr and ignored.
        image_width: Width in pixels of every incoming frame; must be > 0.
        image_height: Height in pixels of every incoming frame; must be > 0.
        output_specifier: Optional printf-style filename pattern taking the
            frame index (e.g. ``output-%d.png``). May be None, in which case
            PNG data goes to stdout.
        **kwargs: The remaining CLI options, forwarded to ``remove``.

    Raises:
        click.BadParameter: If a frame dimension is not positive, or if
            ``output_specifier`` cannot be formatted with an integer index.
    """
    if image_width <= 0:
        raise click.BadParameter("must be positive", param_hint="image_width")
    if image_height <= 0:
        raise click.BadParameter("must be positive", param_hint="image_height")

    # --extras is best-effort: a bad value is reported but does not abort.
    if extras:
        try:
            parsed_extras = json.loads(extras)
        except json.JSONDecodeError as err:
            click.echo(f"Ignoring invalid --extras JSON: {err}", err=True)
        else:
            if isinstance(parsed_extras, dict):
                kwargs.update(parsed_extras)
            else:
                click.echo("Ignoring --extras: expected a JSON object", err=True)

    if output_specifier:
        # Expand "~" once so that the directory we create and the paths we
        # save to agree.
        output_specifier = os.path.expanduser(output_specifier)
        try:
            output_specifier % 0
        except (TypeError, ValueError) as err:
            raise click.BadParameter(
                f"cannot format with a frame index: {err}",
                param_hint="--output_specifier",
            ) from err

        output_dir = os.path.dirname(os.path.abspath(output_specifier))
        os.makedirs(output_dir, exist_ok=True)

    session = new_session(model)
    bytes_per_img = image_width * image_height * 3

    def img_to_byte_array(img: Image.Image) -> bytes:
        """Encode ``img`` as PNG and return the encoded bytes."""
        buff = io.BytesIO()
        img.save(buff, format="PNG")
        return buff.getvalue()

    async def connect_stdin_stdout():
        """Wrap the process's stdin/stdout in an asyncio reader/writer pair."""
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)

        await loop.connect_read_pipe(lambda: protocol, sys.stdin)
        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, sys.stdout
        )

        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def main() -> None:
        """Process frames from stdin until the input is exhausted."""
        reader, writer = await connect_stdin_stdout()

        idx = 0
        while True:
            try:
                img_bytes = await reader.readexactly(bytes_per_img)
            except asyncio.IncompleteReadError as err:
                # End of input; report any trailing bytes that do not form a
                # complete frame instead of dropping them silently.
                if err.partial:
                    click.echo(
                        f"Discarding {len(err.partial)} trailing bytes "
                        f"(incomplete frame)",
                        err=True,
                    )
                break

            img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
            output = remove(img, session=session, **kwargs)

            if output_specifier:
                output.save(output_specifier % idx, format="PNG")
            else:
                writer.write(img_to_byte_array(output))
                # Apply back-pressure and make sure the frame actually reaches
                # stdout before the next one is produced or the loop exits.
                await writer.drain()

            idx += 1

    asyncio.run(main())