import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from contextlib import nullcontext
from huggingface_hub import hf_hub_download
# Initialize the model
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
# List of concept embeddings compatible with SD v1.4
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style",
]
def download_concept_embedding(concept_name):
    try:
        # Download the learned_embeds.bin file from the Hub
        embed_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model"
        )
        return embed_path
    except Exception as e:
        print(f"Error downloading {concept_name}: {str(e)}")
        return None
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    # The file maps a single placeholder token (e.g. "<cat-toy>") to its embedding
    token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[token]
    # Cast to the text encoder's dtype (important when the pipeline runs in fp16)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)
    # Add the concept token to the tokenizer
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"Tokenizer already contains the token {token}.")
    # Resize token embeddings and insert the concept embedding
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
def generate_images(prompt):
    images = []
    failed_concepts = []
    for concept in concepts:
        try:
            # Create a fresh pipeline for each concept so added tokens don't accumulate
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
            # Download and load the concept embedding
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer
            )
            # Use a random seed so repeated prompts give varied results
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)
            # Prepend the concept token to the prompt
            concept_prompt = f"{token} {prompt}"
            # Generate the image; autocast only applies on CUDA
            ctx = autocast("cuda") if device == "cuda" else nullcontext()
            with ctx:
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]
            images.append(image)
            # Clean up to free GPU memory before the next concept
            del pipe
            if device == "cuda":
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)
            continue
    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")
    # Return one slot per concept, padding with None for any failures
    while len(images) < len(concepts):
        images.append(None)
    return images[:len(concepts)]
# Create the Gradio interface, with one output image per concept
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(len(concepts))],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate an image for each artistic concept in the list above, drawn from the SD Concepts Library"
)
# Launch the app
iface.launch()