#!/usr/bin/env python

from __future__ import annotations

import hmac
import os
import random
import warnings

import gradio as gr
import numpy as np
import PIL.Image
import torch
from diffusers import DiffusionPipeline

# Markdown shown at the top of the UI. The literal is split across lines only
# for readability; implicit concatenation yields the exact same text.
DESCRIPTION = (
    'This space is an API service meant to be used by VideoChain and VideoQuest.\n'
    'Want to use this space for yourself? '
    'Please use the original code: '
    '[https://huggingface.co/spaces/hysts/SD-XL]'
    '(https://huggingface.co/spaces/hysts/SD-XL)'
)
# Warn visitors when no CUDA device is visible; generation needs a GPU.
if torch.cuda.is_available() is False:
    DESCRIPTION = DESCRIPTION + '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'

# Largest seed value offered by the seed slider and by randomize_seed_fn.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders; overridable via the environment.
MAX_IMAGE_SIZE = int(os.getenv('MAX_IMAGE_SIZE', '1024'))
# Feature flags: each is enabled only when its env var is exactly '1'.
USE_TORCH_COMPILE = os.getenv('USE_TORCH_COMPILE') == '1'
ENABLE_CPU_OFFLOAD = os.getenv('ENABLE_CPU_OFFLOAD') == '1'
# Token that callers must pass to generate().
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
if not os.environ.get('SECRET_TOKEN'):
    # The fallback above is visible in this public source (and an empty value
    # accepts an empty token), so the token check protects nothing unless
    # SECRET_TOKEN is configured. Keep the fallback for backward
    # compatibility, but make the misconfiguration loud.
    warnings.warn(
        'SECRET_TOKEN is unset or empty; the API token check is effectively '
        'public. Set the SECRET_TOKEN environment variable.',
        stacklevel=1,
    )

# Inference device: the first CUDA GPU when present, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if not torch.cuda.is_available():
    # No CUDA: no model is loaded, so generate() cannot run in this mode.
    pipe = refiner = None
else:
    # SDXL base pipeline, loaded from the fp16 safetensors variant.
    pipe = DiffusionPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0',
        variant='fp16',
        use_safetensors=True,
        torch_dtype=torch.float16,
    )
    # SDXL refiner pipeline, loaded with the same weight options.
    refiner = DiffusionPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-refiner-1.0',
        variant='fp16',
        use_safetensors=True,
        torch_dtype=torch.float16,
    )

    if not ENABLE_CPU_OFFLOAD:
        # Place both pipelines directly on the GPU.
        pipe.to(device)
        refiner.to(device)
    else:
        # Delegate device placement to diffusers' model CPU offload.
        pipe.enable_model_cpu_offload()
        refiner.enable_model_cpu_offload()

    if USE_TORCH_COMPILE:
        # Only the base UNet is compiled; the refiner is left as-is.
        pipe.unet = torch.compile(
            pipe.unet,
            mode='reduce-overhead',
            fullgraph=True,
        )
    
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in [0, MAX_SEED] if requested, else *seed*.

    Args:
        seed: The seed currently selected in the UI.
        randomize_seed: When true, ignore *seed* and draw a new one.

    Returns:
        The seed to use for the next generation.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed


def generate(prompt: str,
             negative_prompt: str = '',
             prompt_2: str = '',
             negative_prompt_2: str = '',
             use_negative_prompt: bool = False,
             use_prompt_2: bool = False,
             use_negative_prompt_2: bool = False,
             seed: int = 0,
             width: int = 1024,
             height: int = 1024,
             guidance_scale_base: float = 5.0,
             guidance_scale_refiner: float = 5.0,
             num_inference_steps_base: int = 50,
             num_inference_steps_refiner: int = 50,
             apply_refiner: bool = False,
             secret_token: str = '') -> PIL.Image.Image:
    """Generate one image with the SDXL base pipeline, optionally refined.

    Args:
        prompt: Main text prompt.
        negative_prompt, prompt_2, negative_prompt_2: Optional extra prompts;
            each is only forwarded when its matching ``use_*`` flag is true,
            otherwise ``None`` is passed to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` shared by both stages.
        width, height: Output size in pixels for the base pipeline.
        guidance_scale_base, num_inference_steps_base: Base-stage settings.
        guidance_scale_refiner, num_inference_steps_refiner: Refiner settings.
        apply_refiner: When true, the base stage returns latents which are
            then passed to the refiner as ``image``.
        secret_token: Must equal the configured ``SECRET_TOKEN``.

    Returns:
        The first generated image.

    Raises:
        gr.Error: If the token is wrong, or if no pipeline was loaded
            (no CUDA device available at startup).
    """
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched; tolerate a missing token from API callers.
    if not hmac.compare_digest((secret_token or '').encode('utf-8'),
                               SECRET_TOKEN.encode('utf-8')):
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')

    # Pipelines are only loaded when CUDA is available at import time;
    # fail with a readable message instead of "NoneType is not callable".
    if pipe is None or (apply_refiner and refiner is None):
        raise gr.Error('No GPU is available; this demo does not work on CPU.')

    # Slider/API values may arrive as floats; these arguments must be ints.
    generator = torch.Generator().manual_seed(int(seed))

    # Prompt arguments are identical for the base and refiner stages.
    prompt_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt if use_negative_prompt else None,
        prompt_2=prompt_2 if use_prompt_2 else None,
        negative_prompt_2=negative_prompt_2 if use_negative_prompt_2 else None,
    )

    base_images = pipe(**prompt_kwargs,
                       width=int(width),
                       height=int(height),
                       guidance_scale=guidance_scale_base,
                       num_inference_steps=int(num_inference_steps_base),
                       generator=generator,
                       # Hand latents to the refiner; otherwise decode to PIL.
                       output_type='latent' if apply_refiner else 'pil').images
    if not apply_refiner:
        return base_images[0]

    return refiner(**prompt_kwargs,
                   guidance_scale=guidance_scale_refiner,
                   num_inference_steps=int(num_inference_steps_refiner),
                   image=base_images,
                   generator=generator).images[0]

# Build the Gradio UI, wire its events, and launch the server.
# NOTE(review): Gradio presumably numbers endpoints (fn_index) in registration
# order, so the event registration order below is kept unchanged for API clients.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Box():
        with gr.Row():
            # Masked input so the token is not echoed on screen.
            secret_token = gr.Text(
                label='Secret Token',
                max_lines=1,
                placeholder='Enter your secret token',
                type='password',
            )
            prompt = gr.Text(
                label='Prompt',
                show_label=False,
                max_lines=1,
                placeholder='Enter your prompt',
                container=False,
            )
            run_button = gr.Button('Run', scale=0)
        result = gr.Image(label='Result', show_label=False)
        with gr.Accordion('Advanced options', open=False):
            with gr.Row():
                use_negative_prompt = gr.Checkbox(label='Use negative prompt',
                                                  value=False)
                use_prompt_2 = gr.Checkbox(label='Use prompt 2', value=False)
                use_negative_prompt_2 = gr.Checkbox(
                    label='Use negative prompt 2', value=False)
            # The optional prompt boxes start hidden; the checkboxes above
            # toggle their visibility.
            negative_prompt = gr.Text(
                label='Negative prompt',
                max_lines=1,
                placeholder='Enter a negative prompt',
                visible=False,
            )
            prompt_2 = gr.Text(
                label='Prompt 2',
                max_lines=1,
                placeholder='Enter your prompt',
                visible=False,
            )
            negative_prompt_2 = gr.Text(
                label='Negative prompt 2',
                max_lines=1,
                placeholder='Enter a negative prompt',
                visible=False,
            )

            seed = gr.Slider(label='Seed',
                             minimum=0,
                             maximum=MAX_SEED,
                             step=1,
                             value=0)
            randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
            with gr.Row():
                width = gr.Slider(
                    label='Width',
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label='Height',
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            apply_refiner = gr.Checkbox(label='Apply refiner', value=False)
            with gr.Row():
                guidance_scale_base = gr.Slider(
                    label='Guidance scale for base',
                    minimum=1,
                    maximum=20,
                    step=0.1,
                    value=5.0)
                num_inference_steps_base = gr.Slider(
                    label='Number of inference steps for base',
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=50)
            # Refiner settings are only shown when "Apply refiner" is ticked.
            with gr.Row(visible=False) as refiner_params:
                guidance_scale_refiner = gr.Slider(
                    label='Guidance scale for refiner',
                    minimum=1,
                    maximum=20,
                    step=0.1,
                    value=5.0)
                num_inference_steps_refiner = gr.Slider(
                    label='Number of inference steps for refiner',
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=50)

    # Each checkbox shows/hides the component it controls; these handlers are
    # UI-only (unqueued, not exposed as API endpoints).
    for toggle, target in (
        (use_negative_prompt, negative_prompt),
        (use_prompt_2, prompt_2),
        (use_negative_prompt_2, negative_prompt_2),
        (apply_refiner, refiner_params),
    ):
        toggle.change(
            fn=lambda x: gr.update(visible=x),
            inputs=toggle,
            outputs=target,
            queue=False,
            api_name=False,
        )

    # Argument order must match generate()'s positional parameters.
    inputs = [
        prompt,
        negative_prompt,
        prompt_2,
        negative_prompt_2,
        use_negative_prompt,
        use_prompt_2,
        use_negative_prompt_2,
        seed,
        width,
        height,
        guidance_scale_base,
        guidance_scale_refiner,
        num_inference_steps_base,
        num_inference_steps_refiner,
        apply_refiner,
        secret_token,
    ]
    # Every trigger first updates the seed (randomized if requested), then
    # runs generate(). Only the prompt-submit chain is published as 'run'.
    for trigger, generate_api_name in (
        (prompt.submit, 'run'),
        (negative_prompt.submit, False),
        (run_button.click, False),
    ):
        trigger(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=generate,
            inputs=inputs,
            outputs=result,
            api_name=generate_api_name,
        )
# Enable Gradio's request queue (max_size=6) and start the server.
demo.queue(max_size=6).launch()