import itertools
import json
import math
import os

import comfy.model_management as model_management
import folder_paths
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from skimage.filters import gaussian
from skimage.util import compare_images

from ..log import log
from ..utils import np2tensor, pil2tensor, tensor2pil

# try:
#     from cv2.ximgproc import guidedFilter
# except ImportError:
#     log.warning("cv2.ximgproc.guidedFilter not found, use opencv-contrib-python")


def gaussian_kernel(
    kernel_size: int, sigma_x: float, sigma_y: float, device=None
):
    """Return a normalized 2D Gaussian of shape ``(kernel_size, kernel_size)``.

    Sample positions span ``[-1, 1]`` on both axes, so ``sigma_x`` and
    ``sigma_y`` are relative to the kernel half-width, not in pixels.
    ``sigma_x`` sets the spread along dim 0 (rows) and ``sigma_y`` the spread
    along dim 1 (columns). The returned weights sum to 1.
    """
    axis = torch.linspace(-1, 1, kernel_size, device=device)
    rows = axis.unsqueeze(1)  # (K, 1): varies along dim 0
    cols = axis.unsqueeze(0)  # (1, K): varies along dim 1
    exponent = rows * rows / (2.0 * sigma_x * sigma_x) + cols * cols / (
        2.0 * sigma_y * sigma_y
    )
    weights = torch.exp(-exponent)
    return weights / weights.sum()


class MTB_CoordinatesToString:
    """Serialize one frame of a BATCH_COORDINATES list to a JSON string."""

    RETURN_TYPES = ("STRING",)
    FUNCTION = "convert"
    CATEGORY = "mtb/coordinates"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "coordinates": ("BATCH_COORDINATES",),
                "frame": ("INT",),
            }
        }

    def convert(
        self, coordinates: list[list[tuple[int, int]]], frame: int
    ) -> tuple[str]:
        """Return the points of ``frame`` as ``[{"x": .., "y": ..}, ...]`` JSON.

        Args:
            coordinates: one list of ``(x, y)`` points per frame.
            frame: frame index; clamped into ``[0, len(coordinates) - 1]`` so an
                out-of-range index selects the nearest existing frame.

        Raises:
            ValueError: if ``coordinates`` contains no frames.
        """
        if not coordinates:
            raise ValueError("No coordinates to convert")

        # Clamp (the previous ``max`` always jumped to the last frame).
        frame = min(max(frame, 0), len(coordinates) - 1)
        output = [{"x": x, "y": y} for x, y in coordinates[frame]]

        return (json.dumps(output),)


class MTB_ExtractCoordinatesFromImage:
    """Extract 2D points from a batch of images based on a threshold."""

    RETURN_TYPES = ("BATCH_COORDINATES", "IMAGE")
    FUNCTION = "extract"
    CATEGORY = "mtb/coordinates"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "threshold": ("FLOAT",),
                "max_points": ("INT", {"default": 50, "min": 0}),
            },
            "optional": {"image": ("IMAGE",), "mask": ("MASK",)},
        }

    def extract(
        self,
        threshold: float,
        max_points: int,
        image: torch.Tensor | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[list[list[tuple[int, int]]], torch.Tensor]:
        if image is not None:
            batch_count, height, width, channel_count = image.shape
            imgs = image
        else:
            if mask is None:
                raise ValueError("Must provide either image or mask")
            batch_count, height, width = mask.shape
            channel_count = 1
            imgs = mask

        if channel_count not in [1, 2, 3, 4]:
            raise ValueError(f"Incorrect channel count: {channel_count}")

        all_points: list[list[tuple[int, int]]] = []
        debug_images = torch.zeros(
            (batch_count, height, width, 3),
            dtype=torch.uint8,
            device=imgs.device,
        )

        for i, img in enumerate(imgs):
            if channel_count == 1:
                alpha_channel = img if len(img.shape) == 2 else img[:, :, 0]
            elif channel_count == 2:
                alpha_channel = img[:, :, 1]
            elif channel_count == 4:
                alpha_channel = img[:, :, 3]
            else:
                # get intensity
                alpha_channel = img[:, :, :3].max(dim=2)[0]

            points = (alpha_channel > threshold).nonzero(as_tuple=False)

            if len(points) > max_points:
                indices = torch.randperm(points.size(0), device=img.device)[
                    :max_points
                ]
                points = points[indices]

            points = [(int(y.item()), int(x.item())) for x, y in points]
            all_points.append(points)

            for x, y in points:
                self._draw_circle(debug_images[i], (x, y), 5)

        return (all_points, debug_images)

    @staticmethod
    def _draw_circle(
        image: torch.Tensor, center: tuple[int, int], radius: int
    ):
        """Draw a 5px circle on the image."""
        x0, y0 = center
        for x in range(-radius, radius + 1):
            for y in range(-radius, radius + 1):
                in_radius = x**2 + y**2 <= radius**2
                in_bounds = (
                    0 <= x0 + x < image.shape[1]
                    and 0 <= y0 + y < image.shape[0]
                )
                if in_radius and in_bounds:
                    image[y0 + y, x0 + x] = torch.tensor(
                        [255, 255, 255],
                        dtype=torch.uint8,
                        device=image.device,
                    )


class MTB_ColorCorrectGPU:
    """Various color correction methods using only Torch."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "force_gpu": ("BOOLEAN", {"default": True}),
                "clamp": ([True, False], {"default": True}),
                "gamma": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "contrast": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "exposure": (
                    "FLOAT",
                    {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01},
                ),
                "offset": (
                    "FLOAT",
                    {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01},
                ),
                "hue": (
                    "FLOAT",
                    {"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01},
                ),
                "saturation": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "value": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
            },
            "optional": {"mask": ("MASK",)},
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "correct"
    CATEGORY = "mtb/image processing"

    @staticmethod
    def get_device(tensor: torch.Tensor, force_gpu: bool):
        if force_gpu:
            if torch.cuda.is_available():
                return torch.device("cuda")
            elif (
                hasattr(torch.backends, "mps")
                and torch.backends.mps.is_available()
            ):
                return torch.device("mps")
            elif hasattr(torch, "hip") and torch.hip.is_available():
                return torch.device("hip")
        return (
            tensor.device
        )  # model_management.get_torch_device() # torch.device("cpu")

    @staticmethod
    def rgb_to_hsv(image: torch.Tensor):
        r, g, b = image.unbind(-1)
        max_rgb, argmax_rgb = image.max(-1)
        min_rgb, _ = image.min(-1)

        diff = max_rgb - min_rgb

        h = torch.empty_like(max_rgb)
        s = diff / (max_rgb + 1e-7)
        v = max_rgb

        h[argmax_rgb == 0] = (g - b)[argmax_rgb == 0] / (diff + 1e-7)[
            argmax_rgb == 0
        ]
        h[argmax_rgb == 1] = (
            2.0 + (b - r)[argmax_rgb == 1] / (diff + 1e-7)[argmax_rgb == 1]
        )
        h[argmax_rgb == 2] = (
            4.0 + (r - g)[argmax_rgb == 2] / (diff + 1e-7)[argmax_rgb == 2]
        )
        h = (h / 6.0) % 1.0

        h = h.unsqueeze(-1)
        s = s.unsqueeze(-1)
        v = v.unsqueeze(-1)

        return torch.cat((h, s, v), dim=-1)

    @staticmethod
    def hsv_to_rgb(hsv: torch.Tensor):
        h, s, v = hsv.unbind(-1)
        h = h * 6.0

        i = torch.floor(h)
        f = h - i
        p = v * (1.0 - s)
        q = v * (1.0 - s * f)
        t = v * (1.0 - s * (1.0 - f))

        i = i.long() % 6

        mask = torch.stack(
            (i == 0, i == 1, i == 2, i == 3, i == 4, i == 5), -1
        )

        rgb = torch.stack(
            (
                torch.where(
                    mask[..., 0],
                    v,
                    torch.where(
                        mask[..., 1],
                        q,
                        torch.where(
                            mask[..., 2],
                            p,
                            torch.where(
                                mask[..., 3],
                                p,
                                torch.where(mask[..., 4], t, v),
                            ),
                        ),
                    ),
                ),
                torch.where(
                    mask[..., 0],
                    t,
                    torch.where(
                        mask[..., 1],
                        v,
                        torch.where(
                            mask[..., 2],
                            v,
                            torch.where(
                                mask[..., 3],
                                q,
                                torch.where(mask[..., 4], p, p),
                            ),
                        ),
                    ),
                ),
                torch.where(
                    mask[..., 0],
                    p,
                    torch.where(
                        mask[..., 1],
                        p,
                        torch.where(
                            mask[..., 2],
                            t,
                            torch.where(
                                mask[..., 3],
                                v,
                                torch.where(mask[..., 4], v, q),
                            ),
                        ),
                    ),
                ),
            ),
            dim=-1,
        )

        return rgb

    def correct(
        self,
        image: torch.Tensor,
        force_gpu: bool,
        clamp: bool,
        gamma: float = 1.0,
        contrast: float = 1.0,
        exposure: float = 0.0,
        offset: float = 0.0,
        hue: float = 0.0,
        saturation: float = 1.0,
        value: float = 1.0,
        mask: torch.Tensor | None = None,
    ):
        device = self.get_device(image, force_gpu)
        image = image.to(device)

        if mask is not None:
            if mask.shape[0] != image.shape[0]:
                mask = mask.expand(image.shape[0], -1, -1)

            mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3)
            mask = mask.to(device)

        model_management.throw_exception_if_processing_interrupted()
        adjusted = image.pow(1 / gamma) * (2.0**exposure) * contrast + offset

        model_management.throw_exception_if_processing_interrupted()
        hsv = self.rgb_to_hsv(adjusted)
        hsv[..., 0] = (hsv[..., 0] + hue) % 1.0  # Hue
        hsv[..., 1] = hsv[..., 1] * saturation  # Saturation
        hsv[..., 2] = hsv[..., 2] * value  # Value
        adjusted = self.hsv_to_rgb(hsv)

        model_management.throw_exception_if_processing_interrupted()
        if clamp:
            adjusted = torch.clamp(adjusted, 0.0, 1.0)

        # apply mask
        result = (
            adjusted
            if mask is None
            else torch.where(mask > 0, adjusted, image)
        )

        if not force_gpu:
            result = result.cpu()

        return (result,)


class MTB_ColorCorrect:
    """Various color correction methods"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "clamp": ([True, False], {"default": True}),
                "gamma": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "contrast": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "exposure": (
                    "FLOAT",
                    {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01},
                ),
                "offset": (
                    "FLOAT",
                    {"default": 0.0, "min": -5.0, "max": 5.0, "step": 0.01},
                ),
                "hue": (
                    "FLOAT",
                    {"default": 0.0, "min": -0.5, "max": 0.5, "step": 0.01},
                ),
                "saturation": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
                "value": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.01},
                ),
            },
            "optional": {"mask": ("MASK",)},
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "correct"
    CATEGORY = "mtb/image processing"

    @staticmethod
    def gamma_correction_tensor(image, gamma):
        gamma_inv = 1.0 / gamma
        return image.pow(gamma_inv)

    @staticmethod
    def contrast_adjustment_tensor(image, contrast):
        r, g, b = image.unbind(-1)

        # Using Adobe RGB luminance weights.
        luminance_image = 0.33 * r + 0.71 * g + 0.06 * b
        luminance_mean = torch.mean(luminance_image.unsqueeze(-1))

        # Blend original with mean luminance using contrast factor as blend ratio.
        contrasted = image * contrast + (1.0 - contrast) * luminance_mean
        return torch.clamp(contrasted, 0.0, 1.0)

    @staticmethod
    def exposure_adjustment_tensor(image, exposure):
        return image * (2.0**exposure)

    @staticmethod
    def offset_adjustment_tensor(image, offset):
        return image + offset

    @staticmethod
    def hsv_adjustment(image: torch.Tensor, hue, saturation, value):
        images = tensor2pil(image)
        out = []
        for img in images:
            hsv_image = img.convert("HSV")

            h, s, v = hsv_image.split()

            h = h.point(lambda x: (x + hue * 255) % 256)
            s = s.point(lambda x: int(x * saturation))
            v = v.point(lambda x: int(x * value))

            hsv_image = Image.merge("HSV", (h, s, v))
            rgb_image = hsv_image.convert("RGB")
            out.append(rgb_image)
        return pil2tensor(out)

    @staticmethod
    def hsv_adjustment_tensor_not_working(
        image: torch.Tensor, hue, saturation, value
    ):
        """Abandonning for now"""
        image = image.squeeze(0).permute(2, 0, 1)

        max_val, _ = image.max(dim=0, keepdim=True)
        min_val, _ = image.min(dim=0, keepdim=True)
        delta = max_val - min_val

        hue_image = torch.zeros_like(max_val)
        mask = delta != 0.0

        r, g, b = image[0], image[1], image[2]
        hue_image[mask & (max_val == r)] = ((g - b) / delta)[
            mask & (max_val == r)
        ] % 6.0
        hue_image[mask & (max_val == g)] = ((b - r) / delta)[
            mask & (max_val == g)
        ] + 2.0
        hue_image[mask & (max_val == b)] = ((r - g) / delta)[
            mask & (max_val == b)
        ] + 4.0

        saturation_image = delta / (max_val + 1e-7)
        value_image = max_val

        hue_image = (hue_image + hue) % 1.0
        saturation_image = torch.where(
            mask, saturation * saturation_image, saturation_image
        )
        value_image = value * value_image

        c = value_image * saturation_image
        x = c * (1 - torch.abs((hue_image % 2) - 1))
        m = value_image - c

        prime_image = torch.zeros_like(image)
        prime_image[0] = torch.where(
            max_val == r, c, torch.where(max_val == g, x, prime_image[0])
        )
        prime_image[1] = torch.where(
            max_val == r, x, torch.where(max_val == g, c, prime_image[1])
        )
        prime_image[2] = torch.where(
            max_val == g, x, torch.where(max_val == b, c, prime_image[2])
        )

        rgb_image = prime_image + m

        rgb_image = rgb_image.permute(1, 2, 0).unsqueeze(0)

        return rgb_image

    def correct(
        self,
        image: torch.Tensor,
        clamp: bool,
        gamma: float = 1.0,
        contrast: float = 1.0,
        exposure: float = 0.0,
        offset: float = 0.0,
        hue: float = 0.0,
        saturation: float = 1.0,
        value: float = 1.0,
        mask: torch.Tensor | None = None,
    ):
        if mask is not None:
            if mask.shape[0] != image.shape[0]:
                mask = mask.expand(image.shape[0], -1, -1)

            mask = mask.unsqueeze(-1).expand(-1, -1, -1, 3)

        # Apply color correction operations
        adjusted = self.gamma_correction_tensor(image, gamma)
        adjusted = self.contrast_adjustment_tensor(adjusted, contrast)
        adjusted = self.exposure_adjustment_tensor(adjusted, exposure)
        adjusted = self.offset_adjustment_tensor(adjusted, offset)
        adjusted = self.hsv_adjustment(adjusted, hue, saturation, value)

        if clamp:
            adjusted = torch.clamp(image, 0.0, 1.0)

        result = (
            adjusted
            if mask is None
            else torch.where(mask > 0, adjusted, image)
        )

        return (result,)


class MTB_ImageCompare:
    """Compare two images and return a difference image"""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "imageA": ("IMAGE",),
                "imageB": ("IMAGE",),
                "mode": (
                    ["checkerboard", "diff", "blend"],
                    {"default": "checkerboard"},
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "compare"
    CATEGORY = "mtb/image"

    def compare(self, imageA: torch.Tensor, imageB: torch.Tensor, mode):
        if imageA.dim() == 4:
            batch_count = imageA.size(0)
            return (
                torch.cat(
                    tuple(
                        self.compare(imageA[i], imageB[i], mode)[0]
                        for i in range(batch_count)
                    ),
                    dim=0,
                ),
            )

        num_channels_A = imageA.size(2)
        num_channels_B = imageB.size(2)

        # handle RGBA/RGB mismatch
        if num_channels_A == 3 and num_channels_B == 4:
            imageA = torch.cat(
                (imageA, torch.ones_like(imageA[:, :, 0:1])), dim=2
            )
        elif num_channels_B == 3 and num_channels_A == 4:
            imageB = torch.cat(
                (imageB, torch.ones_like(imageB[:, :, 0:1])), dim=2
            )
        match mode:
            case "diff":
                compare_image = torch.abs(imageA - imageB)
            case "blend":
                compare_image = 0.5 * (imageA + imageB)
            case "checkerboard":
                imageA = imageA.numpy()
                imageB = imageB.numpy()
                compared_channels = [
                    torch.from_numpy(
                        compare_images(
                            imageA[:, :, i], imageB[:, :, i], method=mode
                        )
                    )
                    for i in range(imageA.shape[2])
                ]

                compare_image = torch.stack(compared_channels, dim=2)
            case _:
                compare_image = None
                raise ValueError(f"Unknown mode {mode}")

        compare_image = compare_image.unsqueeze(0)

        return (compare_image,)


import requests


class MTB_LoadImageFromUrl:
    """Load an image from the given URL"""

    # Seconds to wait on the server (connect and read) before giving up.
    TIMEOUT = 30

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "url": (
                    "STRING",
                    {
                        "default": "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "load"
    CATEGORY = "mtb/IO"

    def load(self, url):
        """Download ``url`` and return it as a one-image IMAGE tensor.

        The EXIF orientation tag, if any, is applied.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within ``TIMEOUT``.
        """
        import io

        # ``response.content`` is fully read and content-decoded (gzip etc.),
        # unlike the raw socket stream; the timeout avoids hanging the queue.
        response = requests.get(url, timeout=self.TIMEOUT)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        image = ImageOps.exif_transpose(image)
        return (pil2tensor(image),)


class MTB_Blur:
    """Blur an image using a Gaussian filter."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "sigmaX": (
                    "FLOAT",
                    {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01},
                ),
                "sigmaY": (
                    "FLOAT",
                    {"default": 3.0, "min": 0.0, "max": 200.0, "step": 0.01},
                ),
            },
            "optional": {"sigmasX": ("FLOATS",), "sigmasY": ("FLOATS",)},
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur"
    CATEGORY = "mtb/image processing"

    def blur(
        self, image: torch.Tensor, sigmaX, sigmaY, sigmasX=None, sigmasY=None
    ):
        """Gaussian-blur every frame of ``image`` and return ``(blurred,)``.

        ``sigmasX``/``sigmasY`` give a per-frame sigma and override the
        scalars. If only ``sigmasX`` is given, it is reused for Y.

        NOTE(review): the sigma tuple is ``(sigmaX, sigmaY, 0)`` over a frame
        laid out ``(H, W, C)``. So "X" is applied along rows (vertical) and
        "Y" along columns. Kept as-is for compatibility; confirm whether the
        names should be swapped.

        Raises:
            ValueError: if a per-frame sigma list does not match the batch size.
        """
        batch_size = image.size(0)

        if sigmasX is not None:
            if sigmasY is None:
                sigmasY = sigmasX
            if len(sigmasX) != batch_size:
                raise ValueError(
                    f"SigmasX must have same length as image, sigmasX is {len(sigmasX)} but the batch size is {image.size(0)}"
                )
            if len(sigmasY) != batch_size:
                raise ValueError(
                    f"SigmasY must have same length as image, sigmasY is {len(sigmasY)} but the batch size is {batch_size}"
                )
            sigmas = list(zip(sigmasX, sigmasY))
        else:
            sigmas = [(sigmaX, sigmaY)] * batch_size

        # skimage works on NumPy; ``.cpu()`` lets CUDA inputs through.
        image_np = image.detach().cpu().numpy() * 255

        blurred_images = [
            gaussian(frame, sigma=(sx, sy, 0), channel_axis=2)
            for frame, (sx, sy) in zip(image_np, sigmas)
        ]

        return (np2tensor(np.array(blurred_images)).squeeze(0),)


class MTB_Sharpen:
    """Sharpens an image using a Gaussian kernel."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "sharpen_radius": (
                    "INT",
                    {"default": 1, "min": 1, "max": 31, "step": 1},
                ),
                "sigma_x": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
                ),
                "sigma_y": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1},
                ),
                "alpha": (
                    "FLOAT",
                    {"default": 1.0, "min": 0.0, "max": 5.0, "step": 0.1},
                ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "do_sharp"
    CATEGORY = "mtb/image processing"

    def do_sharp(
        self,
        image: torch.Tensor,
        sharpen_radius: int,
        sigma_x: float,
        sigma_y: float,
        alpha: float,
    ):
        """Apply an unsharp-style kernel to a ``(B, H, W, C)`` image.

        The kernel is ``-(alpha * 10)`` times a Gaussian, with its center
        adjusted so the weights sum to 1. ``alpha = 0`` is therefore the
        identity. The output is clamped to ``[0, 1]``.
        """
        if sharpen_radius == 0:
            return (image,)

        _, height, width, channels = image.shape

        kernel_size = 2 * sharpen_radius + 1
        # Build the kernel on the image's device (a CPU kernel breaks CUDA input).
        kernel = gaussian_kernel(
            kernel_size, sigma_x, sigma_y, device=image.device
        ) * -(alpha * 10)

        # Modify center of kernel to make it a sharpening kernel
        center = kernel_size // 2
        kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

        # One copy per channel for a depthwise convolution (groups=channels).
        kernel = (
            kernel.to(dtype=image.dtype).repeat(channels, 1, 1).unsqueeze(1)
        )
        tensor_image = image.permute(0, 3, 1, 2)

        # "reflect" requires the pad to be smaller than the image side.
        pad_mode = (
            "reflect" if sharpen_radius < min(height, width) else "replicate"
        )
        tensor_image = F.pad(
            tensor_image,
            (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius),
            pad_mode,
        )
        # The padding already supplies the borders, so a valid (unpadded)
        # convolution yields exactly H x W; no extra zero-pad-and-crop needed.
        sharpened = F.conv2d(tensor_image, kernel, groups=channels)

        sharpened = sharpened.permute(0, 2, 3, 1)
        result = torch.clamp(sharpened, 0, 1)

        return (result,)


# https://github.com/lllyasviel/AdverseCleaner/blob/main/clean.py
# def deglaze_np_img(np_img):
#     y = np_img.copy()
#     for _ in range(64):
#         y = cv2.bilateralFilter(y, 5, 8, 8)
#     for _ in range(4):
#         y = guidedFilter(np_img, y, 4, 16)
#     return y


# class DeglazeImage:
#     """Remove adversarial noise from images"""

#     @classmethod
#     def INPUT_TYPES(cls):
#         return {"required": {"image": ("IMAGE",)}}

#     CATEGORY = "mtb/image processing"

#     RETURN_TYPES = ("IMAGE",)
#     FUNCTION = "deglaze_image"

#     def deglaze_image(self, image):
#         return (np2tensor(deglaze_np_img(tensor2np(image))),)


class MTB_MaskToImage:
    """Render a mask as an RGB image: ``color`` where set, ``background`` elsewhere."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "color": ("COLOR",),
                "background": ("COLOR", {"default": "#000000"}),
            },
            "optional": {
                "invert": ("BOOLEAN", {"default": False}),
            },
        }

    CATEGORY = "mtb/generate"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "render_mask"

    def render_mask(self, mask, color, background, invert=False):
        """Composite ``color`` over ``background`` through each mask frame.

        With ``invert`` the mask is replaced by ``1.0 - mask`` first.
        """
        source = 1.0 - mask if invert else mask
        rendered = []

        for pil_mask in tensor2pil(source):
            alpha = pil_mask.convert("L")
            log.debug(
                f"Converted mask to PIL Image format, size: {alpha.size}"
            )

            fill = Image.new("RGBA", alpha.size, color=color)
            base = Image.new("RGBA", alpha.size, color=background)
            # Pixels take ``fill`` where the mask is white, ``base`` where black.
            rendered.append(Image.composite(fill, base, alpha).convert("RGB"))

        return (pil2tensor(rendered),)


class MTB_ColoredImage:
    """Constant color image of given size."""

    def __init__(self) -> None:
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "color": ("COLOR",),
                "width": ("INT", {"default": 512, "min": 16, "max": 8160}),
                "height": ("INT", {"default": 512, "min": 16, "max": 8160}),
            },
            "optional": {
                "foreground_image": ("IMAGE",),
                "foreground_mask": ("MASK",),
                "invert": ("BOOLEAN", {"default": False}),
                "mask_opacity": (
                    "FLOAT",
                    {"default": 1.0, "step": 0.1, "min": 0},
                ),
            },
        }

    CATEGORY = "mtb/generate"

    RETURN_TYPES = ("IMAGE",)

    FUNCTION = "render_img"

    def resize_and_crop(self, img: Image.Image, target_size: tuple[int, int]):
        """Scale ``img`` to cover ``target_size`` (width, height), then center-crop."""
        target_w, target_h = target_size
        scale = max(target_w / img.width, target_h / img.height)
        # ``int()`` truncation could lose a pixel when the float product lands
        # just below the exact value (e.g. ``1 / 49 * 49 == 0.9999999999999999``).
        # The crop would then run past the image and add a black edge.
        new_size = (
            max(target_w, round(img.width * scale)),
            max(target_h, round(img.height * scale)),
        )
        img = img.resize(new_size, Image.LANCZOS)
        left = (img.width - target_w) // 2
        top = (img.height - target_h) // 2
        return img.crop((left, top, left + target_w, top + target_h))

    def resize_and_crop_thumbnails(
        self, img: Image.Image, target_size: tuple[int, int]
    ):
        """Shrink ``img`` to fit ``target_size`` and crop around its center.

        Note: ``Image.thumbnail`` resizes ``img`` in place.
        """
        img.thumbnail(target_size, Image.LANCZOS)
        left = (img.width - target_size[0]) / 2
        top = (img.height - target_size[1]) / 2
        right = (img.width + target_size[0]) / 2
        bottom = (img.height + target_size[1]) / 2
        return img.crop((left, top, right, bottom))

    @staticmethod
    def process_mask(
        mask: torch.Tensor | None,
        invert: bool,
        batch_size: int,
    ) -> list[Image.Image | None]:
        """Return one PIL mask per foreground frame, or ``None`` per frame.

        A single mask is repeated across the batch; ``invert`` uses
        ``1.0 - mask``.

        Raises:
            ValueError: if the mask count matches neither 1 nor ``batch_size``.
        """
        if mask is None:
            return [None] * batch_size

        masks = tensor2pil(mask if not invert else 1.0 - mask)

        if len(masks) == 1 and batch_size > 1:
            masks = masks * batch_size

        if len(masks) != batch_size:
            raise ValueError(
                "Foreground image and mask must have the same batch size"
            )

        return masks

    def render_img(
        self,
        color: str,
        width: int,
        height: int,
        foreground_image: torch.Tensor | None = None,
        foreground_mask: torch.Tensor | None = None,
        invert: bool = False,
        mask_opacity: float = 1.0,
    ) -> tuple[torch.Tensor]:
        """Render a ``width`` x ``height`` image of ``color``.

        If a foreground is given, each frame is composited over it:
        - with a mask: the foreground is cover-resized and blended through the
          mask, scaled by ``mask_opacity``;
        - without a mask: the foreground must carry its own alpha (RGBA).

        Raises:
            ValueError: foreground without mask is not RGBA, or the mask/image
                batch sizes disagree.
        """
        background = Image.new("RGBA", (width, height), color=color)

        if foreground_image is None:
            return (pil2tensor([background.convert("RGB")]),)

        fg_images = tensor2pil(foreground_image)
        fg_masks = self.process_mask(foreground_mask, invert, len(fg_images))

        output: list[Image.Image] = []
        for fg_image, fg_mask in zip(fg_images, fg_masks, strict=False):
            fg_image = self.resize_and_crop(fg_image, background.size)

            if fg_mask is not None:
                # Normalize to 8-bit "L", a valid mask mode for ``composite``.
                fg_mask = self.resize_and_crop(
                    fg_mask.convert("L"), background.size
                )

                scaled = np.asarray(fg_mask).astype(np.float64) * mask_opacity
                # Clip so opacities > 1 saturate instead of wrapping in uint8.
                fg_mask = Image.fromarray(
                    np.clip(scaled, 0, 255).astype(np.uint8)
                )
                output.append(
                    Image.composite(
                        fg_image.convert("RGBA"), background, fg_mask
                    ).convert("RGB")
                )
            else:
                if fg_image.mode != "RGBA":
                    raise ValueError(
                        f"Foreground image must be in 'RGBA' mode when no mask is provided, got {fg_image.mode}"
                    )
                output.append(
                    Image.alpha_composite(background, fg_image).convert("RGB")
                )

        return (pil2tensor(output),)


class MTB_ImagePremultiply:
    """Premultiply image with mask.

    The mask becomes each image's alpha channel (``putalpha``); the RGB values
    themselves are not multiplied. With ``invert`` False the alpha is
    ``1.0 - mask``, with ``invert`` True it is ``mask``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "invert": ("BOOLEAN", {"default": False}),
            }
        }

    CATEGORY = "mtb/image"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("RGBA",)
    FUNCTION = "premultiply"

    def premultiply(self, image, mask, invert):
        """Attach the (optionally inverted) mask as alpha and return ``(rgba,)``.

        A single mask is shared by every frame; otherwise frame ``i`` uses
        mask ``i``.
        """
        frames = tensor2pil(image)
        alpha_source = mask if invert else 1.0 - mask
        alphas = [pil_mask.convert("L") for pil_mask in tensor2pil(alpha_source)]
        shared = len(mask) == 1

        result = []
        for idx, frame in enumerate(frames):
            frame.putalpha(alphas[0] if shared else alphas[idx])
            result.append(frame)

        return (pil2tensor(result),)


class MTB_ImageResizeFactor:
    """Resize images (and an optional mask) by a scale factor.

    Extracted mostly from WAS Node Suite, with a few edits (most notably
    multiple image support) and fewer features.
    """

    # "linear" / "trilinear" are 1D / 3D modes in F.interpolate and raise on
    # 4D (B, C, H, W) input; their 2D counterpart is "bilinear".
    _MODE_FOR_2D = {"linear": "bilinear", "trilinear": "bilinear"}
    # align_corners may only be set for interpolating modes: passing it at all
    # (even False) with nearest / area / nearest-exact raises a ValueError.
    _ALIGN_CORNERS_MODES = ("bilinear", "bicubic")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "factor": (
                    "FLOAT",
                    {"default": 2, "min": 0.01, "max": 16.0, "step": 0.01},
                ),
                "supersample": ("BOOLEAN", {"default": True}),
                "resampling": (
                    [
                        "nearest",
                        "linear",
                        "bilinear",
                        "bicubic",
                        "trilinear",
                        "area",
                        "nearest-exact",
                    ],
                    {"default": "nearest"},
                ),
            },
            "optional": {
                "mask": ("MASK",),
            },
        }

    CATEGORY = "mtb/image"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "resize"

    @classmethod
    def _interpolate(cls, tensor: torch.Tensor, size, resampling: str):
        """Resize a channel-first (B, C, H, W) tensor to ``size`` = (H, W)."""
        mode = cls._MODE_FOR_2D.get(resampling, resampling)
        return F.interpolate(
            tensor,
            size=size,
            mode=mode,
            align_corners=True if mode in cls._ALIGN_CORNERS_MODES else None,
        )

    def resize(
        self,
        image: torch.Tensor,
        factor: float,
        supersample: bool,
        resampling: str,
        mask=None,
    ):
        """Scale ``image`` by ``factor``; optionally apply and return a mask.

        Args:
            image: tensor of shape (H, W, C) or (B, H, W, C).
            factor: scale applied to height and width (each clamped to >= 1 px).
            supersample: when True, the resized result is doubled in size
                once more with the same filter.
            resampling: an ``F.interpolate`` mode; "linear" and "trilinear"
                are treated as "bilinear" for 2D images.
            mask: optional tensor of shape (H, W) or (B, H, W). It is resized
                to the output resolution, clamped to [0, 1] and multiplied
                into the image.

        Returns:
            ``(image, mask)`` in the input's layout: image (B, H', W', C) or
            (H', W', C); mask (B, H', W') or (H', W'). Without an input mask,
            the mask output is all ones.

        Raises:
            ValueError: on an image or mask with an unsupported rank.
        """
        if image.dim() not in (3, 4):  # HxWxC or BxHxWxC
            raise ValueError(
                "Expected image tensor of shape (H, W, C) or (B, H, W, C)"
            )

        # Remember the input layout so an unbatched image is returned unbatched.
        unbatched = image.dim() == 3
        batch = image.unsqueeze(0) if unbatched else image
        batch = batch.permute(0, 3, 1, 2)  # BxHxWxC -> BxCxHxW for F.interpolate

        _, _, height, width = batch.shape
        # Clamp to 1px so very small factors never request an empty output.
        size = (max(1, int(height * factor)), max(1, int(width * factor)))

        resized = self._interpolate(batch, size, resampling)
        if supersample:
            # NOTE(review): this doubles the output resolution rather than
            # downsampling back to `size`; kept as-is for compatibility.
            resized = self._interpolate(
                resized, (size[0] * 2, size[1] * 2), resampling
            )

        resized = resized.permute(0, 2, 3, 1)  # back to BxHxWxC
        out_h, out_w = resized.shape[1], resized.shape[2]

        if mask is None:
            out_mask = torch.ones(
                (resized.shape[0], out_h, out_w),
                dtype=resized.dtype,
                device=resized.device,
            )
        else:
            if mask.dim() not in (2, 3):
                raise ValueError(
                    "Expected mask tensor of shape (H, W) or (B, H, W)"
                )
            batched_mask = mask.unsqueeze(0) if mask.dim() == 2 else mask
            # Resize straight to the final output size (add a channel axis
            # for F.interpolate); filters like bicubic can overshoot, hence
            # the clamp.
            out_mask = (
                self._interpolate(
                    batched_mask.unsqueeze(1).to(resized),
                    (out_h, out_w),
                    resampling,
                )
                .squeeze(1)
                .clamp(0.0, 1.0)
            )
            # Broadcast the mask over the channel axis.
            resized = resized * out_mask.unsqueeze(-1)

        if unbatched:
            resized = resized.squeeze(0)
            out_mask = out_mask.squeeze(0)

        return (resized, out_mask)


class MTB_SaveImageGrid:
    """Save all the images in the input batch as a grid of images."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.type = "output"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                "save_intermediate": ("BOOLEAN", {"default": False}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save_images"

    OUTPUT_NODE = True

    CATEGORY = "mtb/IO"

    @staticmethod
    def _grid_shape(total_images: int) -> tuple[int, int]:
        """Return ``(columns, rows)`` of a near-square grid fitting every image.

        Rows are ``ceil(sqrt(n))`` and columns are just enough to hold all
        ``n`` images, so ``columns * rows >= n`` always holds. Whenever the
        previous ``floor(sqrt(n)) x ceil(sqrt(n))`` layout was large enough,
        this gives that exact layout.

        Raises:
            ValueError: if ``total_images`` is less than 1.
        """
        if total_images < 1:
            raise ValueError("Cannot create an image grid from an empty list")
        rows = math.isqrt(total_images)
        if rows * rows < total_images:
            rows += 1
        columns = math.ceil(total_images / rows)
        return columns, rows

    def create_image_grid(self, image_list):
        """Paste ``image_list`` row by row onto a single RGB canvas.

        Every cell takes the size of the first image.

        Raises:
            ValueError: if ``image_list`` is empty.
        """
        columns, rows = self._grid_shape(len(image_list))

        cell_width, cell_height = image_list[0].size
        grid_image = Image.new("RGB", (columns * cell_width, rows * cell_height))

        for index, image in enumerate(image_list):
            row, column = divmod(index, columns)
            x, y = column * cell_width, row * cell_height
            grid_image.paste(image, (x, y, x + cell_width, y + cell_height))

        return grid_image

    def save_images(
        self,
        images,
        filename_prefix="Grid",
        save_intermediate=False,
        prompt=None,
        extra_pnginfo=None,
    ):
        """Save the batch as one grid PNG and optionally each frame separately.

        Args:
            images: image batch; each ``images[i]`` is turned into a PIL image
                via ``255 * image``, clipped to uint8 (assumes (H, W, C)
                frames with values in [0, 1]; confirm).
            filename_prefix: prefix passed to ``folder_paths.get_save_image_path``.
            save_intermediate: also write every frame as its own PNG.
            prompt: when given, stored as JSON in the PNG "prompt" text chunk.
            extra_pnginfo: each key is stored as a JSON PNG text chunk.

        Returns:
            The ComfyUI UI payload referencing the saved grid file.

        Raises:
            ValueError: if ``images`` is empty.
        """
        if len(images) == 0:
            raise ValueError("No images to save")

        (
            full_output_folder,
            filename,
            counter,
            subfolder,
            filename_prefix,
        ) = folder_paths.get_save_image_path(
            filename_prefix,
            self.output_dir,
            images[0].shape[1],  # width
            images[0].shape[0],  # height
        )
        image_list = []
        batch_counter = counter

        metadata = PngInfo()
        if prompt is not None:
            metadata.add_text("prompt", json.dumps(prompt))
        if extra_pnginfo is not None:
            for x in extra_pnginfo:
                metadata.add_text(x, json.dumps(extra_pnginfo[x]))

        for idx, image in enumerate(images):
            i = 255.0 * image.cpu().numpy()
            img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
            image_list.append(img)

            if save_intermediate:
                file = f"{filename}_batch-{idx:03}_{batch_counter:05}_.png"
                img.save(
                    os.path.join(full_output_folder, file),
                    pnginfo=metadata,
                    compress_level=4,
                )

            batch_counter += 1

        file = f"{filename}_{counter:05}_.png"
        grid = self.create_image_grid(image_list)
        grid.save(
            os.path.join(full_output_folder, file),
            pnginfo=metadata,
            compress_level=4,
        )

        results = [
            {"filename": file, "subfolder": subfolder, "type": self.type}
        ]
        return {"ui": {"images": results}}


class MTB_ImageTileOffset:
    """Mimics an old photoshop technique to check for seamless textures.

    The image is offset by one tile with wrap-around, so the original borders
    meet inside the frame where any seam becomes visible.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "tilesX": ("INT", {"default": 2, "min": 1}),
                "tilesY": ("INT", {"default": 2, "min": 1}),
            }
        }

    CATEGORY = "mtb/generate"

    RETURN_TYPES = ("IMAGE",)

    FUNCTION = "tile_image"

    def tile_image(
        self, image: torch.Tensor, tilesX: int = 2, tilesY: int = 2
    ):
        """Shift ``image`` down and right by one tile, wrapping around.

        Args:
            image: batch indexed as (B, H, W, C).
            tilesX: tiles across the width; horizontal shift is ``W // tilesX``.
            tilesY: tiles across the height; vertical shift is ``H // tilesY``.

        Returns:
            A 1-tuple holding a new tensor with the same shape as ``image``.

        Raises:
            ValueError: if either tile count is below 1.
        """
        if tilesX < 1 or tilesY < 1:
            raise ValueError("The number of tiles must be at least 1.")

        _, height, width, _ = image.shape

        # Rolling the whole frame is equivalent to moving tile (i, j) to
        # ((i + 1) % tilesY, (j + 1) % tilesX) when the size divides evenly,
        # and it also keeps the leftover rows/columns instead of zeroing them.
        shifted = torch.roll(
            image,
            shifts=(height // tilesY, width // tilesX),
            dims=(1, 2),
        )
        return (shifted,)


# Node classes exported by this module. Presumably collected by the package
# loader to register them with ComfyUI; confirm against the loader.
__nodes__ = [
    MTB_ColorCorrect,
    MTB_ColorCorrectGPU,
    MTB_ImageCompare,
    MTB_ImageTileOffset,
    MTB_Blur,
    # DeglazeImage,
    MTB_MaskToImage,
    MTB_ColoredImage,
    MTB_ImagePremultiply,
    MTB_ImageResizeFactor,
    MTB_SaveImageGrid,
    MTB_LoadImageFromUrl,
    MTB_Sharpen,
    MTB_ExtractCoordinatesFromImage,
    MTB_CoordinatesToString,
]