"""Inference script for MIDI-3D: lifts an RGB image and its instance
segmentation map into per-instance 3D meshes and exports the composed
scene as a GLB file.

Example invocation (script and image paths are illustrative):
    python infer_midi.py --rgb scene_rgb.png --seg scene_seg.png --output-dir outputs
"""

import argparse
import os
from typing import Any, Union

import numpy as np
import torch
import trimesh
from huggingface_hub import snapshot_download
from PIL import Image, ImageOps
from skimage import measure

from midi.pipelines.pipeline_midi import MIDIPipeline
from midi.utils.smoothing import smooth_gpu


def preprocess_image(rgb_image, seg_image):
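    """Pad the RGB/segmentation pair so the instances fit with margin, then square the canvas."""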
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    width, height = rgb_image.size

    seg_np = np.array(seg_image)
    rows, cols = np.where(seg_np > 0)
    if rows.size == 0 or cols.size == 0:
        return rgb_image, seg_image

    # compute the bounding box of combined instances
    min_row, max_row = min(rows), max(rows)
    min_col, max_col = min(cols), max(cols)
    # side length of the smallest centered square covering the instances
    # (rows index image height, columns index image width)
    L = max(
        max(abs(max_row - height // 2), abs(min_row - height // 2)) * 2,
        max(abs(max_col - width // 2), abs(min_col - width // 2)) * 2,
    )

    # enlarge the canvas if the instances occupy more than 80% of a side
    if L > width * 0.8:
        width = int(L / 4 * 5)
    if L > height * 0.8:
        height = int(L / 4 * 5)
    rgb_new = Image.new("RGB", (width, height), (255, 255, 255))
    seg_new = Image.new("L", (width, height), 0)
    x_offset = (width - rgb_image.size[0]) // 2
    y_offset = (height - rgb_image.size[1]) // 2
    rgb_new.paste(rgb_image, (x_offset, y_offset))
    seg_new.paste(seg_image, (x_offset, y_offset))

    # pad the canvas to a square
    max_dim = max(width, height)
    rgb_new = ImageOps.expand(
        rgb_new, border=(0, 0, max_dim - width, max_dim - height), fill="white"
    )
    seg_new = ImageOps.expand(
        seg_new, border=(0, 0, max_dim - width, max_dim - height), fill=0
    )

    return rgb_new, seg_new


def split_rgb_mask(rgb_image, seg_image):
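    """Split a scene into per-instance RGB images (on white), binary masks,
    and one copy of the full scene image per instance."""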
    if isinstance(rgb_image, str):
        rgb_image = Image.open(rgb_image)
    if isinstance(seg_image, str):
        seg_image = Image.open(seg_image)
    rgb_image = rgb_image.convert("RGB")
    seg_image = seg_image.convert("L")

    rgb_array = np.array(rgb_image)
    seg_array = np.array(seg_image)

    label_ids = np.unique(seg_array)
    label_ids = label_ids[label_ids > 0]

    instance_rgbs, instance_masks, scene_rgbs = [], [], []

    for segment_id in sorted(label_ids):
        # Here we set the background to white
        white_background = np.ones_like(rgb_array) * 255

        mask = np.zeros_like(seg_array, dtype=np.uint8)
        mask[seg_array == segment_id] = 255
        segment_rgb = white_background.copy()
        segment_rgb[mask == 255] = rgb_array[mask == 255]

        segment_rgb_image = Image.fromarray(segment_rgb)
        segment_mask_image = Image.fromarray(mask)
        instance_rgbs.append(segment_rgb_image)
        instance_masks.append(segment_mask_image)
        scene_rgbs.append(rgb_image)

    return instance_rgbs, instance_masks, scene_rgbs


@torch.no_grad()
def run_midi(
    pipe: Any,
    rgb_image: Union[str, Image.Image],
    seg_image: Union[str, Image.Image],
    seed: int,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.0,
    do_image_padding: bool = False,
) -> trimesh.Scene:
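    """Run the MIDI pipeline on one scene and compose the per-instance meshes
    into a single trimesh.Scene."""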
    if do_image_padding:
        rgb_image, seg_image = preprocess_image(rgb_image, seg_image)
    instance_rgbs, instance_masks, scene_rgbs = split_rgb_mask(rgb_image, seg_image)

    num_instances = len(instance_rgbs)
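    # a single batched call denoises all instances jointly; num_instances
    # tells the pipeline how many latents belong to the same scene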
    outputs = pipe(
        image=instance_rgbs,
        mask=instance_masks,
        image_scene=scene_rgbs,
        attention_kwargs={"num_instances": num_instances},
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        decode_progressive=True,
        return_dict=False,
    )

    # extract an iso-surface from each instance's logits grid via marching cubes
    trimeshes = []
    for logits_, grid_size, bbox_size, bbox_min, bbox_max in zip(*outputs):
        grid_logits = logits_.view(grid_size)
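        # smooth the logits grid on the GPU before surface extraction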
        grid_logits = smooth_gpu(grid_logits, method="gaussian", sigma=1)
        torch.cuda.empty_cache()
        vertices, faces, normals, _ = measure.marching_cubes(
            grid_logits.float().cpu().numpy(), 0, method="lewiner"
        )
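        # map voxel coordinates back into the instance's world-space bounding box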
        vertices = vertices / grid_size * bbox_size + bbox_min

        # Trimesh
        mesh = trimesh.Trimesh(vertices.astype(np.float32), np.ascontiguousarray(faces))
        trimeshes.append(mesh)

    # compose the output meshes
    scene = trimesh.Scene(trimeshes)

    return scene


if __name__ == "__main__":
    device = "cuda"
    dtype = torch.bfloat16

    parser = argparse.ArgumentParser()
    parser.add_argument("--rgb", type=str, required=True)
    parser.add_argument("--seg", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-inference-steps", type=int, default=50)
    parser.add_argument("--guidance-scale", type=float, default=7.0)
    parser.add_argument("--do-image-padding", action="store_true")
    parser.add_argument("--output-dir", type=str, default="./")
    args = parser.parse_args()

    # fetch pretrained weights from the Hugging Face Hub (no-op if already cached)
    local_dir = "pretrained_weights/MIDI-3D"
    snapshot_download(repo_id="VAST-AI/MIDI-3D", local_dir=local_dir)
    pipe: MIDIPipeline = MIDIPipeline.from_pretrained(local_dir).to(device, dtype)
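    # install MIDI's custom adapter on the self-attention of the listed transformer blocks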
    pipe.init_custom_adapter(
        set_self_attn_module_names=[
            "blocks.8",
            "blocks.9",
            "blocks.10",
            "blocks.11",
            "blocks.12",
        ]
    )

    # make sure the output directory exists before exporting
    os.makedirs(args.output_dir, exist_ok=True)

    run_midi(
        pipe,
        rgb_image=args.rgb,
        seg_image=args.seg,
        seed=args.seed,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        do_image_padding=args.do_image_padding,
    ).export(os.path.join(args.output_dir, "output.glb"))