import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
import spaces
from omegaconf import OmegaConf

import os
import subprocess
import sys

# Run the repository's setup script at import time, before the `lama` imports
# below. NOTE(review): presumably it fetches the lama sources and the ./big-lama
# checkpoint that get_inpaint_model() loads — confirm against setup.sh.
rc = subprocess.call("./setup.sh")
if rc != 0:
    # Keep going (the environment may already be prepared), but make a failed
    # setup visible instead of silently continuing into later import/file errors.
    print(f"WARNING: ./setup.sh exited with status {rc}", file=sys.stderr)

# Put the local `lama` directory on the import path. NOTE(review): presumably
# needed because saicinpainting's own modules import `saicinpainting.*` as a
# top-level package — confirm.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lama'))

from lama.saicinpainting.evaluation.refinement import refine_predict
from lama.saicinpainting.training.trainers import load_checkpoint
from lama.saicinpainting.evaluation.utils import move_to_device


# Load the model
@functools.lru_cache(maxsize=1)
def get_inpaint_model():
    """
    Loads and initializes the inpainting model.

    The result is memoised, so the YAML configs and the checkpoint are read
    only on the first call in a process, not on every inpainting request.

    Returns: Tuple of (model, predict_config)
        model: the checkpoint loaded from ./big-lama/models/, frozen and moved
            to the selected device ('cuda' if available, else 'cpu').
        predict_config: the OmegaConf config from ./default.yaml, with
            ``model.path``, ``refiner.gpu_ids`` and ``device`` (a string)
            filled in.

    Note: every call returns the same objects, so callers must treat them as
    read-only.
    """
    predict_config = OmegaConf.load('./default.yaml')
    predict_config.model.path = './big-lama/models/'
    # NOTE(review): the refiner is pinned to GPU id '0'; refinement presumably
    # cannot run on a CPU-only host — confirm against refine_predict.
    predict_config.refiner.gpu_ids = '0'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Stored as a string; inpaint() turns it back into a torch.device.
    predict_config.device = str(device)

    train_config = OmegaConf.load('./big-lama/config.yaml')
    # Presumably switches the training wrapper to inference-only use and turns
    # off its visualizer — confirm against saicinpainting's trainers.
    train_config.training_model.predict_only = True
    train_config.visualizer.kind = 'noop'

    checkpoint_path = os.path.join(predict_config.model.path,
                                   predict_config.model.checkpoint)

    model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location=device)
    model.freeze()
    model.to(device)
    return model, predict_config

@spaces.GPU
def inpaint(input_dict, refinement_enabled=False):
    """
    Performs image inpainting on the input image using the provided mask.

    Args:
        input_dict: value of the ``gr.ImageEditor`` component — a dict whose
            'background' is the PIL input image and whose 'layers' is a list of
            PIL layers. The first layer is used as the mask: pixels whose
            grayscale value is above 0 are inpainted.
        refinement_enabled: if truthy, run ``refine_predict`` (slower);
            otherwise run a single forward pass of the model.

    Returns:
        PIL.Image.Image: the inpainted RGB image.

    Raises:
        gr.Error: if no image was uploaded or no mask layer is present, so the
            user gets a readable message instead of a TypeError/IndexError.
    """
    if not input_dict or input_dict.get("background") is None:
        raise gr.Error("Please upload an image first.")
    layers = input_dict.get("layers") or []
    if not layers:
        raise gr.Error("Please draw a mask over the area you want to remove.")

    # RGB uint8 (H, W, 3) -> float32 in [0, 1] -> channels-first (3, H, W).
    image = np.array(input_dict["background"].convert("RGB")).astype('float32') / 255
    np_input_image = np.transpose(image, (2, 0, 1))
    # Binary {0, 1} mask with a leading channel axis: (1, H, W).
    np_input_mask = np.array(pil_to_binary_mask(layers[0]))[None, :, :]

    inpaint_model, predict_config = get_inpaint_model()
    device = torch.device(predict_config.device)
    modulo = predict_config.dataset.pad_out_to_modulo

    batch = {
        # Pad H and W up to a multiple of `modulo`, then add a batch axis.
        'image': torch.tensor(pad_img_to_modulo(np_input_image, modulo))[None].to(device),
        'mask': torch.tensor(pad_img_to_modulo(np_input_mask, modulo))[None].float().to(device),
        # Original (H, W), used to crop the padding off the prediction.
        'unpad_to_size': [torch.tensor([np_input_image.shape[1]]),
                          torch.tensor([np_input_image.shape[2]])],
    }

    if refinement_enabled:
        # NOTE(review): this path does not crop explicitly; refine_predict is
        # assumed to use batch['unpad_to_size'] itself — confirm.
        result = refine_predict(batch, inpaint_model, **predict_config.refiner)
        result = result[0].permute(1, 2, 0).detach().cpu().numpy()
    else:
        with torch.no_grad():
            batch = move_to_device(batch, device)
            batch['mask'] = (batch['mask'] > 0) * 1
            batch = inpaint_model(batch)
            result = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy()
            unpad_to_size = batch.get('unpad_to_size', None)
            if unpad_to_size is not None:
                orig_height, orig_width = unpad_to_size
                result = result[:orig_height, :orig_width]

    result = np.clip(result * 255, 0, 255).astype('uint8')
    return Image.fromarray(result)

def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    A value that is already a multiple of ``mod`` is returned unchanged.
    """
    remainder = x % mod
    return x if remainder == 0 else (x // mod + 1) * mod

def pad_img_to_modulo(img, mod):
    """Pad a channels-first array so its height and width are multiples of ``mod``.

    Args:
        img: array of shape ``(C, H, W)``.
        mod: the multiple that both ``H`` and ``W`` are rounded up to.

    Returns:
        A new array of shape ``(C, H', W')``. Padding is appended only at the
        bottom and right edges using ``mode='symmetric'``; the channel axis is
        never padded.
    """
    _, h, w = img.shape
    pad_bottom = ceil_modulo(h, mod) - h
    pad_right = ceil_modulo(w, mod) - w
    pad_widths = ((0, 0), (0, pad_bottom), (0, pad_right))
    return np.pad(img, pad_widths, mode='symmetric')

def pil_to_binary_mask(pil_image, threshold=0, max_scale=1):
    """
    Converts a PIL image to a binary mask.

    The image is converted to grayscale. Every pixel strictly greater than
    ``threshold`` becomes ``max_scale``, and every other pixel becomes 0.

    Args:
        pil_image (PIL.Image): The input image. It is passed through
            ``np.array`` and ``Image.fromarray``, so array input that
            ``Image.fromarray`` accepts works too.
        threshold (int, optional): Grayscale values above this are treated as
            masked. Defaults to 0.
        max_scale (int, optional): Value written for masked pixels. The result
            is stored as uint8, so keep it within [0, 255]. Defaults to 1.

    Returns:
        PIL.Image: A grayscale ("L") PIL image representing the binary mask.
    """
    grayscale_image = Image.fromarray(np.array(pil_image)).convert("L")
    binary_mask = np.array(grayscale_image) > threshold
    # Vectorised True -> 1 / False -> 0, replacing a per-pixel Python loop.
    mask = binary_mask.astype(np.uint8)
    mask = (mask * max_scale).astype(np.uint8)
    return Image.fromarray(mask).convert("L")

# Forces a fixed 600px height on the image components.
# NOTE(review): confirm these class names match the CSS classes emitted by the
# installed Gradio version; unmatched selectors are silently ignored.
css = ".output-image, .input-image, .image-preview {height: 600px !important}"

# Create Gradio interface
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Image Inpainting")
    gr.Markdown("Upload an image and draw a mask to remove unwanted objects.")
    
    with gr.Row():
        # Editor value is handed to inpaint(), which reads its 'background'
        # image and uses the first entry of 'layers' (the brush strokes) as
        # the mask.
        input_image = gr.ImageEditor(type="pil", label='Input image & Mask', interactive=True, height="auto", width="auto", brush=gr.Brush(colors=['#f2e2cd'], default_size=25))
        output_image = gr.Image(type="pil", label="Output Image", height="auto", width="auto")
    
    with gr.Row():
        refine_checkbox = gr.Checkbox(label="Enable Refinement[SLOWER BUT BETTER]", value=False)
        inpaint_button = gr.Button("Inpaint")

    def inpaint_with_refinement(image, enable_refinement):
        # Maps the two UI inputs onto inpaint()'s `refinement_enabled` argument.
        # NOTE(review): Gradio typically derives the default API endpoint name
        # from the handler's __name__; renaming this may change the public API.
        return inpaint(image, refinement_enabled=enable_refinement)

    inpaint_button.click(
        fn=inpaint_with_refinement,
        inputs=[input_image, refine_checkbox],
        outputs=[output_image]
    )

# Launch the interface
if __name__ == "__main__":
    demo.launch()