#!/usr/bin/env python

from __future__ import annotations

import os
import random

import gradio as gr
import numpy as np
import spaces
import requests
import torch
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline, AutoencoderKL, DiffusionPipeline
from diffusers.utils import load_image
from safety_checker import StableDiffusionSafetyChecker

DESCRIPTION = "# SDXL"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

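# Runtime configuration via environment variables (Hugging Face Spaces conventions).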
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1824"))
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
ENABLE_USE_LORA = os.getenv("ENABLE_USE_LORA", "1") == "1"
ENABLE_USE_VAE = os.getenv("ENABLE_USE_VAE", "1") == "1"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


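# @spaces.GPU requests a GPU worker for the duration of the call on ZeroGPU Spaces.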
@spaces.GPU
def generate(
    prompt: str,
    negative_prompt: str = "",
    prompt_2: str = "",
    negative_prompt_2: str = "",
    use_negative_prompt: bool = False,
    use_prompt_2: bool = False,
    use_negative_prompt_2: bool = False,
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale_base: float = 5.0,
    guidance_scale_refiner: float = 5.0,
    num_inference_steps_base: int = 25,
    num_inference_steps_refiner: int = 25,
    use_vae: bool = False,
    use_lora: bool = False,
    apply_refiner: bool = False,
    model: str = "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    vaecall: str = "stabilityai/sd-vae-ft-mse",
    lora: str = "amazonaws-la/juliette",
    url: str = "https://m.media-amazon.com/images/I/81zPcrN6m+L.jpg",
    lora_scale: float = 0.7,
) -> Image.Image:
    if not torch.cuda.is_available():
        raise gr.Error("This demo does not work on CPU.")

    if use_vae:
        # Swap in an external VAE (e.g. sd-vae-ft-mse for checkpoints shipped without one).
        vae = AutoencoderKL.from_pretrained(vaecall, torch_dtype=torch.float16)
        pipe = DiffusionPipeline.from_pretrained(model, vae=vae, torch_dtype=torch.float16)
    else:
        # Attach the CompVis safety checker so flagged images are filtered out.
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        )
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model, safety_checker=safety_checker, torch_dtype=torch.float16
        )

    if use_lora:
        pipe.load_lora_weights(lora)
        pipe.fuse_lora(lora_scale=lora_scale)
            
    # Fetch the init image for the img2img path and match it to the requested size.
    response = requests.get(url, timeout=30)
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((width, height))

    if ENABLE_CPU_OFFLOAD:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)

    if USE_TORCH_COMPILE:
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

    if apply_refiner:
        # Optional second-stage refiner; the standard SDXL refiner checkpoint is
        # an assumption here: replace it if your base model needs a different one.
        refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
        )
        refiner.to(device)
        
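    # A seeded generator keeps results reproducible for a given seed value.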
    generator = torch.Generator().manual_seed(seed)

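    # Options left unchecked are passed as None so the pipeline ignores them.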
    if not use_negative_prompt:
        negative_prompt = None  # type: ignore
    if not use_prompt_2:
        prompt_2 = None  # type: ignore
    if not use_negative_prompt_2:
        negative_prompt_2 = None  # type: ignore

    if not apply_refiner:
        if isinstance(pipe, StableDiffusionImg2ImgPipeline):
            # img2img derives the output size from the init image and does not
            # accept the SDXL-only prompt_2 arguments.
            return pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale_base,
                num_inference_steps=num_inference_steps_base,
                generator=generator,
                image=init_image,
                output_type="pil",
            ).images[0]
        # Text-to-image path: prompt_2/negative_prompt_2 assume an SDXL-style pipeline.
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            width=width,
            height=height,
            guidance_scale=guidance_scale_base,
            num_inference_steps=num_inference_steps_base,
            generator=generator,
            output_type="pil",
        ).images[0]
    else:
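        # The refiner path assumes an SDXL-style base pipeline that can emit latents.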
        latents = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            width=width,
            height=height,
            guidance_scale=guidance_scale_base,
            num_inference_steps=num_inference_steps_base,
            generator=generator,
            output_type="latent",
        ).images
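        # Second pass: the refiner denoises the base latents with its own settings.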
        image = refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            prompt_2=prompt_2,
            negative_prompt_2=negative_prompt_2,
            guidance_scale=guidance_scale_refiner,
            num_inference_steps=num_inference_steps_refiner,
            image=latents,
            generator=generator,
        ).images[0]
        return image


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        model = gr.Text(label="Model")
        vaecall = gr.Text(label="VAE", visible=False)
        lora = gr.Text(label="LoRA", visible=False)
        lora_scale = gr.Slider(
            label="LoRA scale",
            minimum=0.01,
            maximum=1,
            step=0.01,
            value=0.7,
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
            use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
            use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )
        prompt_2 = gr.Text(
            label="Prompt 2",
            max_lines=1,
            placeholder="Enter your prompt",
            visible=False,
        )
        negative_prompt_2 = gr.Text(
            label="Negative prompt 2",
            max_lines=1,
            placeholder="Enter a negative prompt",
            visible=False,
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=1024,
            )
        use_vae = gr.Checkbox(label="Use VAE", value=False, visible=ENABLE_USE_VAE)
        use_lora = gr.Checkbox(label="Use LoRA", value=False, visible=ENABLE_USE_LORA)
        apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
        with gr.Row():
            guidance_scale_base = gr.Slider(
                label="Guidance scale for base",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_base = gr.Slider(
                label="Number of inference steps for base",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )
        with gr.Row(visible=False) as refiner_params:
            guidance_scale_refiner = gr.Slider(
                label="Guidance scale for refiner",
                minimum=1,
                maximum=20,
                step=0.1,
                value=5.0,
            )
            num_inference_steps_refiner = gr.Slider(
                label="Number of inference steps for refiner",
                minimum=10,
                maximum=100,
                step=1,
                value=25,
            )

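    # Prompt-only examples; every other generate() argument falls back to its default.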
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

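    # Show or hide each optional field in lockstep with its checkbox.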
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
        queue=False,
        api_name=False,
    )
    use_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_prompt_2,
        outputs=prompt_2,
        queue=False,
        api_name=False,
    )
    use_negative_prompt_2.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt_2,
        outputs=negative_prompt_2,
        queue=False,
        api_name=False,
    )
    use_vae.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_vae,
        outputs=vaecall,
        queue=False,
        api_name=False,
    )
    use_lora.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_lora,
        outputs=lora,
        queue=False,
        api_name=False,
    )
    apply_refiner.change(
        fn=lambda x: gr.update(visible=x),
        inputs=apply_refiner,
        outputs=refiner_params,
        queue=False,
        api_name=False,
    )

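    # On submit/click: optionally randomize the seed, then run generation.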
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            prompt_2.submit,
            negative_prompt_2.submit,
            run_button.click,
        ],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            prompt_2,
            negative_prompt_2,
            use_negative_prompt,
            use_prompt_2,
            use_negative_prompt_2,
            seed,
            width,
            height,
            guidance_scale_base,
            guidance_scale_refiner,
            num_inference_steps_base,
            num_inference_steps_refiner,
            use_vae,
            use_lora,
            apply_refiner,
            model,
            vaecall,
            lora,
            lora_scale,
        ],
        outputs=result,
        api_name="run",
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()