import gradio as gr
import torch
from torch import autocast
from contextlib import nullcontext
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
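# Dependencies: pip install gradio torch diffusers transformers huggingface_hub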

# Base model and device configuration (the pipeline itself is loaded per request)
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# List of concept embeddings compatible with SD v1.4
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style"
]
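
# Each repo above is a textual-inversion concept: it ships a learned_embeds.bin
# file mapping a placeholder token (e.g. "<cat-toy>") to one learned CLIP
# embedding. Uncomment the remaining entries to run with up to 5 concepts.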

def download_concept_embedding(concept_name):
    try:
        # Download the learned_embeds.bin file from the Hub
        embed_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model"
        )
        return embed_path
    except Exception as e:
        print(f"Error downloading {concept_name}: {str(e)}")
        return None
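
# Note: hf_hub_download caches downloads in the local Hugging Face cache, so
# repeated requests for the same concept reuse the file instead of re-fetching.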

def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # The file maps a single placeholder token to its learned embedding
    token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[token]

    # Cast to the text encoder's dtype (float16 on CUDA, float32 on CPU)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the concept token to the tokenizer
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"Tokenizer already contains the token {token}")

    # Resize token embeddings and write the learned vector into the new slot
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds

    return token
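
# Illustrative shape of a concept file (an assumption about a typical repo,
# not verified here): sd-concepts-library/cat-toy loads as roughly
#   {"<cat-toy>": tensor of shape (768,)}   # 768 = hidden size of SD v1.4's CLIP encoder
# so the returned token drops straight into a prompt, e.g. "<cat-toy> on a beach".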

def generate_images(prompt):
    images = []
    failed_concepts = []
    
    for concept in concepts:
        try:
            # Create a fresh pipeline for each concept so added tokens don't
            # accumulate across concepts (simple but slow; preloading all
            # embeddings into one shared pipeline would avoid the reloads)
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
            
            # Download and load concept embedding
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue
                
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer
            )
            
            # Generate random seed
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)
            
            # Add concept token to prompt
            concept_prompt = f"{token} {prompt}"
            
            # Generate the image; autocast is only useful on CUDA, so fall
            # back to a no-op context on CPU
            amp_ctx = autocast("cuda") if device == "cuda" else nullcontext()
            with amp_ctx:
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]
            
            images.append(image)
            
            # Clean up to free GPU memory before loading the next pipeline
            del pipe
            if device == "cuda":
                torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)
            continue
    
    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")
    
    # Return one image slot per concept, padding with None for any failures
    while len(images) < len(concepts):
        images.append(None)
    return images[:len(concepts)]

# Create Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(len(concepts))],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate one image per concept from the SD Concepts Library, each using a different textual-inversion embedding"
)

# Launch the app
iface.launch()
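
# To expose a temporary public URL (e.g. when running on Colab), launch with:
#   iface.launch(share=True)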