import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import random
from huggingface_hub import hf_hub_download

# Initialize the model
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Concept embeddings compatible with SD v1.4 (uncomment more to fill all 5 output slots)
concepts = [
    "sd-concepts-library/cat-toy",
    "sd-concepts-library/disco-diffusion-style",
    # "sd-concepts-library/modern-disney-style",
    # "sd-concepts-library/charliebo-artstyle",
    # "sd-concepts-library/redshift-render-style"
]
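
# Illustrative sketch (not executed): other concepts can be discovered via the Hub
# API. list_models() is a real huggingface_hub call; filtering by the
# sd-concepts-library author is the only assumption made here.
#
#   from huggingface_hub import HfApi
#   for model in HfApi().list_models(author="sd-concepts-library", limit=10):
#       print(model.id)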

def download_concept_embedding(concept_name):
    try:
        # Download the learned_embeds.bin file from the Hub
        embed_path = hf_hub_download(
            repo_id=concept_name,
            filename="learned_embeds.bin",
            repo_type="model"
        )
        return embed_path
    except Exception as e:
        print(f"Error downloading {concept_name}: {str(e)}")
        return None

def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer):
    # learned_embeds.bin maps the placeholder token (e.g. "<cat-toy>") to its embedding
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    
    # Add the concept token to the tokenizer
    token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[token]
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(f"Token {token} already exists in the tokenizer")
    
    # Resize token embeddings to make room for the new token
    text_encoder.resize_token_embeddings(len(tokenizer))
    
    # Write the concept embedding into the new slot, cast to the encoder's dtype
    # (the saved embedding is fp32, while the fp16 pipeline's encoder is fp16)
    token_id = tokenizer.convert_tokens_to_ids(token)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds.to(dtype)
    
    return token
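
# Example usage (illustrative): the returned token is the concept's placeholder
# string, e.g. "<cat-toy>", and can be dropped straight into a prompt:
#
#   token = load_learned_embed_in_clip(path, pipe.text_encoder, pipe.tokenizer)
#   image = pipe(f"a photo of {token} on a desk").images[0]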

def generate_images(prompt):
    images = []
    failed_concepts = []
    
    for concept in concepts:
        try:
            # Create a fresh pipeline for each concept
            pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
            
            # Download and load concept embedding
            embed_path = download_concept_embedding(concept)
            if embed_path is None:
                failed_concepts.append(concept)
                continue
                
            token = load_learned_embed_in_clip(
                embed_path,
                pipe.text_encoder,
                pipe.tokenizer
            )
            
            # Generate random seed
            seed = random.randint(1, 999999)
            generator = torch.Generator(device=device).manual_seed(seed)
            
            # Add concept token to prompt
            concept_prompt = f"{token} {prompt}"
            
            # Generate image (autocast only helps on CUDA; disable it on CPU)
            with autocast(device_type=device, enabled=(device == "cuda")):
                image = pipe(
                    concept_prompt,
                    num_inference_steps=20,
                    generator=generator,
                    guidance_scale=7.5
                ).images[0]
            
            images.append(image)
            
            # Clean up to free GPU memory before the next concept
            del pipe
            if device == "cuda":
                torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"Error processing concept {concept}: {str(e)}")
            failed_concepts.append(concept)
            continue
    
    if failed_concepts:
        print(f"Failed to process concepts: {', '.join(failed_concepts)}")
    
    # Return available images, pad with None if some failed
    while len(images) < 5:
        images.append(None)
    return images[:5]
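
# Quick smoke test (illustrative, assumes at least one concept loads successfully):
#
#   images = generate_images("a watercolor landscape")
#   if images[0] is not None:
#       images[0].save("sample.png")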

# Create Gradio interface
iface = gr.Interface(
    fn=generate_images,
    inputs=gr.Textbox(label="Enter your prompt"),
    outputs=[gr.Image(label=f"Concept {i+1}") for i in range(5)],
    title="Multi-Concept Stable Diffusion Generator",
    description="Generate images using 5 different artistic concepts from the SD Concepts Library"
)

# Launch the app
iface.launch()
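
# Alternative launch options (all standard Gradio parameters): share=True gives a
# temporary public URL; server_name/server_port matter inside containers, e.g.
#   iface.launch(server_name="0.0.0.0", server_port=7860)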