# Author: satyanayak
# Commit d6f9ffa: "fresh pipeline for each concept"
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
import os
from transformers import CLIPTokenizer
# Base Stable Diffusion checkpoint that every concept below is paired with.
model_id = "CompVis/stable-diffusion-v1-4"

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repos of textual-inversion concepts compatible with SD v1.4. Each repo is
# expected to provide a `learned_embeds.bin` file. The commented-out entries are
# currently disabled; uncomment them to enable more concepts.
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style"
]
def download_concept_embedding(concept_name):
    """Fetch the ``learned_embeds.bin`` file of a concept repo from the Hugging Face Hub.

    Args:
        concept_name: Hub repo id of the concept, e.g.
            ``"sd-concepts-library/cat-toy"``.

    Returns:
        The local filesystem path returned by ``hf_hub_download``, or ``None``
        if the download failed. Failures are printed, never raised.
    """
    local_path = None
    try:
        local_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model",
        )
    except Exception as err:
        # Best-effort: report and let the caller treat None as "concept unavailable".
        print(f"Error downloading {concept_name}: {str(err)}")
    return local_path
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Inject a textual-inversion concept token into a tokenizer / text-encoder pair.

    The file is loaded as a mapping of token string -> embedding tensor. Only
    its first entry is used. The token is added to ``tokenizer``, the encoder's
    input-embedding table is resized to the new vocabulary size, and the
    token's row is overwritten with the learned embedding.

    Args:
        learned_embeds_path: Path to a ``learned_embeds.bin`` file.
        text_encoder: Model exposing ``resize_token_embeddings`` and
            ``get_input_embeddings()`` (the pipeline's text encoder).
        tokenizer: Tokenizer exposing ``add_tokens``, ``__len__`` and
            ``convert_tokens_to_ids``.

    Returns:
        The concept token string, to be placed in the prompt.

    Raises:
        ValueError: If the file contains no embeddings.
    """
    # NOTE(review): torch.load unpickles the file, and this file comes from the
    # Hub (untrusted). Consider torch.load(..., weights_only=True) on torch>=1.13.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not loaded_learned_embeds:
        raise ValueError(f"No learned embeddings found in {learned_embeds_path}")

    # Only the first entry is used; presumably these files hold a single
    # concept token — TODO confirm for multi-token concepts.
    token, embedding = next(iter(loaded_learned_embeds.items()))

    # add_tokens returns 0 if the token already exists; in that case the
    # existing row is simply overwritten below.
    tokenizer.add_tokens(token)
    # Grow the embedding table so the new token id has a row.
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(token)
    weight = text_encoder.get_input_embeddings().weight
    # The file is loaded on CPU (and may be float32) while the encoder may be
    # float16 on CUDA; convert explicitly instead of relying on implicit casting.
    weight.data[token_id] = embedding.to(dtype=weight.dtype, device=weight.device)
    return token
def _generate_for_concept(concept, prompt):
    """Render one image for a single concept; return None if its embedding is unavailable."""
    from contextlib import nullcontext

    # Download the embedding first so a missing concept never costs a full model load.
    embed_path = download_concept_embedding(concept)
    if embed_path is None:
        return None

    # A fresh pipeline per concept keeps each concept's token out of the others' tokenizers.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    try:
        token = load_learned_embed_in_clip(embed_path, pipe.text_encoder, pipe.tokenizer)

        seed = random.randint(1, 999999)
        generator = torch.Generator(device=device).manual_seed(seed)
        concept_prompt = f"{token} {prompt}"

        # Autocast only on CUDA: CPU autocast runs in bfloat16, which the
        # float32 CPU pipeline neither needs nor expects.
        amp_context = autocast(device) if device == "cuda" else nullcontext()
        with amp_context:
            return pipe(
                concept_prompt,
                num_inference_steps=20,
                generator=generator,
                guidance_scale=7.5,
            ).images[0]
    finally:
        # Release the pipeline on success AND failure, so the next concept's
        # model is never loaded while this one is still resident.
        del pipe
        if device == "cuda":
            torch.cuda.empty_cache()


def generate_images(prompt):
    """Generate one image per configured concept for the given prompt.

    Each concept in ``concepts`` gets its own freshly loaded pipeline with the
    concept's learned token prepended to ``prompt``, sampled with a random seed.

    Args:
        prompt: Free-text prompt entered by the user.

    Returns:
        A list of exactly 5 entries, one per image output of the interface.
        Slot ``i`` holds the image for ``concepts[i]``, or ``None`` if that
        concept failed or there is no concept for that slot. Concepts beyond
        the fifth are not generated because there is no slot to show them in.
    """
    num_slots = 5  # must match the number of gr.Image outputs in the UI
    images = []
    failed_concepts = []
    for concept in concepts[:num_slots]:
        try:
            image = _generate_for_concept(concept, prompt)
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            image = None
        if image is None:
            failed_concepts.append(concept)
        # Keep failed slots as None so every image stays under its own concept's slot.
        images.append(image)

    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")

    # Pad unused slots with None so the UI always receives num_slots values.
    images.extend([None] * (num_slots - len(images)))
    return images
# Gradio UI: one prompt textbox in, five image slots out, matching the
# five-element list that generate_images returns.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    # Derive the count from the enabled concepts (at most 5 are shown) instead
    # of hard-coding a number that drifts when entries are commented out.
    description=(
        f"Generate images using {min(len(concepts), 5)} different artistic "
        "concepts from the SD Concepts Library"
    ),
)
# Launch the app
iface.launch()