import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
import os
from PIL import Image, ImageFilter
from typing import List, Tuple

# NSFW safety checking is enabled only when the SAFETY_CHECKER environment
# variable is exactly "1"; unset or any other value leaves it disabled.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

# Constants
# Hugging Face Hub id of the SDXL base pipeline loaded at startup.
base = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repo from which the SDXL-Lightning UNet checkpoint files are downloaded.
repo = "ByteDance/SDXL-Lightning"
# UI label -> [checkpoint filename inside `repo`, number of inference steps].
checkpoints = {
    "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
}
# Aspect-ratio label -> (width part, height part). The key order is also the
# order of choices in the "Aspect Ratio" dropdown.
aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}


def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    """Choose a (width, height) for *aspect_ratio* that fits a pixel budget.

    Args:
        aspect_ratio: A key of ``aspect_ratios``, e.g. "16:9".
        mode: "portrait" flips the orientation so the image is taller than
            wide; any other value is treated as landscape.
        total_pixels: Upper bound on ``width * height``.
        divisibility: Both sides are rounded down to a multiple of this.

    Returns:
        ``(width, height)``, each a multiple of ``divisibility``, with
        ``width * height <= total_pixels``.

    Raises:
        ValueError: If ``aspect_ratio`` is not a key of ``aspect_ratios``.
    """
    if aspect_ratio not in aspect_ratios:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")

    wide, tall = aspect_ratios[aspect_ratio]
    ratio = wide / tall
    if mode == 'portrait':
        ratio = 1 / ratio  # flip to taller-than-wide

    def snap(size):
        # Round down to the nearest multiple of `divisibility`.
        return size - size % divisibility

    # Start from the height that would exactly spend the budget, then derive
    # the width from it.
    height = snap(int((total_pixels / ratio) ** 0.5))
    width = snap(int(height * ratio))

    # Defensive: step the height down until the area fits the budget.
    while width * height > total_pixels:
        height -= divisibility
        width = snap(int(height * ratio))

    return width, height


# Preset prompts offered in the "Select an Example" dropdown. Each entry holds
# a value for every generation input: prompt, negative prompt, checkpoint
# label (a key of `checkpoints`), aspect-ratio label (a key of
# `aspect_ratios`) and orientation mode.
examples = [
    {"prompt": "A futuristic cityscape at sunset", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
    {"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
    {"prompt": "A portrait of a robot in the style of Renaissance art", "negative_prompt": "Ugly", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
    {"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "negative_prompt": "Ugly", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},
    {"prompt": "A serene landscape with mountains and a river", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
    {"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "negative_prompt": "Ugly", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"},
]


def set_example(selected_prompt):
    """Return the fields of the first example whose prompt is *selected_prompt*.

    The result is the 5-tuple ``(prompt, negative_prompt, ckpt, aspect, mode)``.
    When no example matches, every element of the tuple is ``None``.
    """
    fields = ("prompt", "negative_prompt", "ckpt", "aspect", "mode")
    match = next((ex for ex in examples if ex["prompt"] == selected_prompt), None)
    if match is None:
        return (None,) * len(fields)
    return tuple(match[field] for field in fields)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the SDXL base pipeline. On CUDA, load the "fp16" weight variant in
# half precision; on CPU, keep the library's default dtype and weights.
_pipeline_options = (
    {"torch_dtype": torch.float16, "variant": "fp16"} if device == "cuda" else {}
)
pipe = StableDiffusionXLPipeline.from_pretrained(base, **_pipeline_options).to(device)
del _pipeline_options

if SAFETY_CHECKER:
    # Imported only here, so the local `safety_checker` module and
    # transformers are needed only when the checker is switched on.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    # NOTE(review): transformers deprecates CLIPFeatureExtractor in favour of
    # CLIPImageProcessor; consider switching once the pinned version allows.
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

def check_nsfw_images(
    images: List[Image.Image]
) -> Tuple[List[Image.Image], List[bool]]:
    """Run the NSFW safety checker on each image.

    Relies on the module-level ``feature_extractor`` and ``safety_checker``.
    These are created only when ``SAFETY_CHECKER`` is enabled, so call this
    function only in that case.

    Args:
        images: PIL images to check.

    Returns:
        ``(images, has_nsfw_concepts)``. ``images`` is the same list that was
        passed in. ``has_nsfw_concepts`` concatenates the checker's output for
        each image, in order.
    """
    has_nsfw_concepts = []
    for image in images:
        # Preprocess one image and put its tensors on `device`, the same
        # device the safety checker was moved to. This was previously
        # hard-coded to "cuda", which fails when the app runs on CPU.
        checker_input = feature_extractor(image, return_tensors="pt").to(device)
        # NOTE(review): extending with the raw return value assumes the local
        # safety_checker returns an iterable of per-image booleans. diffusers'
        # StableDiffusionSafetyChecker returns an (images, has_nsfw_concepts)
        # tuple instead; if this copy does too, the images list would land in
        # the result and the caller's any() would always be true. Confirm
        # against safety_checker.py.
        has_nsfw_concepts.extend(
            safety_checker(images=[image], clip_input=checker_input.pixel_values)
        )
    return images, has_nsfw_concepts

# Gradio callback: generate one image from the current UI inputs.
@spaces.GPU(enable_queue=True)
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    """Generate a single image with an SDXL-Lightning checkpoint.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt, passed through to the pipeline.
        ckpt: Key of ``checkpoints`` ("1-Step", "2-Step", "4-Step", "8-Step").
        aspect_ratio: Key of ``aspect_ratios``.
        mode: "landscape" or "portrait" (see ``calculate_resolution``).

    Returns:
        The first generated PIL image. When ``SAFETY_CHECKER`` is enabled and
        the checker flags anything, a Gaussian-blurred copy is returned instead.

    Raises:
        ValueError: If ``aspect_ratio`` is unknown (from calculate_resolution).
        KeyError: If ``ckpt`` is not a key of ``checkpoints``.
    """
    width, height = calculate_resolution(aspect_ratio, mode)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    # SDXL-Lightning samples with "trailing" timestep spacing. The 1-step
    # checkpoint is used with prediction_type="sample" and the others with
    # "epsilon". Set the prediction type explicitly on every request:
    # from_config() copies the *current* scheduler config, so leaving it out
    # would carry "sample" over from an earlier 1-step request into later
    # 2/4/8-step requests.
    prediction_type = "sample" if num_inference_steps == 1 else "epsilon"
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config,
        timestep_spacing="trailing",
        prediction_type=prediction_type,
    )

    # Load the requested Lightning UNet weights into the base pipeline's UNet.
    # NOTE(review): this re-reads the checkpoint file on every request.
    pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device))

    # NOTE(review): diffusers' SDXL pipeline applies classifier-free guidance
    # only when guidance_scale > 1, so with guidance_scale=0 negative_prompt
    # likely has no effect on the output. Confirm before relying on that field.
    results = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        width=width,
        height=height,
    )

    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            # Return a blurred version of the first image instead of the original.
            return images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
        return images[0]
    return results.images[0]



# Gradio Interface
description = """
SDXL-Lightning ByteDance model demo. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            # Dropdown choices come from the `checkpoints` and `aspect_ratios`
            # tables, so they cannot drift out of sync with what
            # generate_image looks up.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape', interactive=True)
            submit = gr.Button(scale=1, variant='primary')

    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the prompt box and clicking the button run the same job.
    _generation_inputs = [prompt, negative_prompt, ckpt, aspect, mode]
    prompt.submit(fn=generate_image, inputs=_generation_inputs, outputs=img)
    submit.click(fn=generate_image, inputs=_generation_inputs, outputs=img)

    # Dropdown for selecting examples; choosing one fills all five inputs above.
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, negative_prompt, ckpt, aspect, mode])

demo.queue().launch()