Spaces:
Running
Running
File size: 3,934 Bytes
7a41953 6df8da7 d6f9ffa 7a41953 6269d98 7a41953 339f63e f67fd7b 7a41953 6df8da7 7a41953 6269d98 7a41953 6df8da7 d6f9ffa 6269d98 6df8da7 f67fd7b 6df8da7 d6f9ffa 6df8da7 6269d98 6df8da7 7a41953 6269d98 7a41953 6269d98 7a41953 6269d98 7a41953 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download
import os
from transformers import CLIPTokenizer
# Base Stable Diffusion checkpoint used to build every pipeline
model_id = "CompVis/stable-diffusion-v1-4"

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub repositories holding concept embeddings compatible with SD v1.4.
# Entries that are commented out are currently disabled.
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style"
]
def download_concept_embedding(concept_name):
    """Fetch a concept's ``learned_embeds.bin`` file from the Hub.

    Args:
        concept_name: Hub repository id of the concept, passed as
            ``repo_id`` to ``hf_hub_download``.

    Returns:
        Whatever ``hf_hub_download`` returns for the file (used by callers
        as a local path), or ``None`` when any exception is raised. The
        error is reported on stdout rather than propagated.
    """
    try:
        local_file = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model",
        )
    except Exception as exc:
        # Best effort: the caller treats None as "this concept failed".
        print(f"Error downloading {concept_name}: {str(exc)}")
        return None
    else:
        return local_file
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    """Register a learned concept token with a tokenizer and text encoder.

    The file at ``learned_embeds_path`` is expected to hold a mapping from
    token string to embedding tensor. Only the first entry of that mapping
    is used.

    Args:
        learned_embeds_path: Path of a file readable by ``torch.load``.
        text_encoder: Object providing ``resize_token_embeddings`` and
            ``get_input_embeddings``; its embedding matrix is modified in
            place.
        tokenizer: Object providing ``add_tokens``,
            ``convert_tokens_to_ids`` and ``len()``; the token is added to
            it.

    Returns:
        The concept token string that was registered.

    Raises:
        ValueError: If the loaded file is not a non-empty mapping.
    """
    # NOTE(security): torch.load unpickles its input, and callers pass files
    # downloaded from the Hub. Where the installed PyTorch supports it,
    # consider passing weights_only=True to restrict what can be loaded.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    if not isinstance(loaded_learned_embeds, dict) or not loaded_learned_embeds:
        raise ValueError(
            f"No learned embeddings found in {learned_embeds_path!r}"
        )

    # Take the first (token, embedding) pair stored in the file.
    token, learned_embedding = next(iter(loaded_learned_embeds.items()))

    # Add the concept token to the tokenizer, then grow the embedding matrix
    # so it has a row for every tokenizer entry.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    # Fetch the matrix only after resizing: the resize may have replaced it.
    token_id = tokenizer.convert_tokens_to_ids(token)
    embedding_weights = text_encoder.get_input_embeddings().weight
    # Match the encoder's dtype/device explicitly (e.g. a half-precision
    # encoder on GPU) instead of relying on an implicit cast.
    embedding_weights.data[token_id] = learned_embedding.to(
        dtype=embedding_weights.dtype,
        device=embedding_weights.device,
    )
    return token
def generate_images(prompt):
    """Generate one image per configured concept for the given prompt.

    For every entry in ``concepts`` the concept embedding is downloaded, a
    new pipeline is built, the concept token is prepended to ``prompt`` and
    a single image is generated with a random seed. A concept that fails at
    any step is reported on stdout and skipped.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        A list of exactly five entries: the images that were generated, in
        the order their concepts succeeded, followed by ``None`` padding.
    """
    # Must equal the number of image outputs declared on the Gradio interface.
    num_outputs = 5
    images = []
    failed_concepts = []

    for concept in concepts:
        pipe = None
        try:
            # Fetch the embedding first so that a failed download does not
            # cost a full model load.
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue

            # A fresh pipeline per concept (presumably so one concept's
            # added token never carries over into the next).
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            ).to(device)

            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer,
            )

            # Random seed so repeated prompts give different images
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)

            # Prepend the concept token to the user's prompt
            concept_prompt = f"{token} {prompt}"

            with autocast(device):
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5,
                ).images[0]
            images.append(image)
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)
        finally:
            # Release the pipeline on every path, including failures, so the
            # previous model is not still referenced while the next loads.
            pipe = None
            torch.cuda.empty_cache()

    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")

    # Pad with None so the result always fills every output slot
    images.extend([None] * (num_outputs - len(images)))
    return images[:num_outputs]
# Build the Gradio UI: one text box in, five image slots out.
# generate_images always returns a five-element list, one per slot.
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different artistic concepts from the SD Concepts Library",
)

# Start serving the app
iface.launch()