File size: 7,113 Bytes
6ba84d9
3f1893b
6ba84d9
3f1893b
3f9ea51
d52abc5
6ba84d9
 
f0be263
 
 
3f1893b
6ba84d9
 
 
 
 
 
 
 
 
 
 
 
f0be263
 
 
 
 
 
 
 
 
 
6ba84d9
 
 
 
 
f0be263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ba84d9
 
 
 
d52abc5
6ba84d9
 
f0be263
 
6ba84d9
 
 
 
 
f0be263
6ba84d9
 
 
 
 
 
f0be263
 
 
6ba84d9
f0be263
3f1893b
 
6ba84d9
 
 
f0be263
6ba84d9
 
 
 
 
 
 
 
 
 
 
 
3f1893b
 
6ba84d9
 
 
 
 
3f1893b
 
 
 
 
 
f0be263
 
3f1893b
6ba84d9
 
3f1893b
6ba84d9
3f1893b
6ba84d9
 
3f1893b
6ba84d9
 
 
 
3f1893b
 
f0be263
 
 
 
 
30023e3
f0be263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ba84d9
 
 
f0be263
 
6ba84d9
3f1893b
6ba84d9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
import uuid
from datetime import datetime
from typing import List, Dict, Optional

from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Model locations: the FLUX.1-dev base checkpoint (Hugging Face hub id) and
# the local directory that holds the control LoRA weight files.
base_path = "black-forest-labs/FLUX.1-dev"    
lora_base_path = "./models"

# Load the pipeline in bfloat16 and swap in the project's custom transformer
# (the custom FluxTransformer2DModel carries the attn_processors that
# clear_cache() and set_single_lora() operate on below).
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer

# Gallery storage: generated images are written to GALLERY_DIR and indexed
# by a single JSON database file (GALLERY_DB).
GALLERY_DIR = "gallery"
os.makedirs(GALLERY_DIR, exist_ok=True)
GALLERY_DB = os.path.join(GALLERY_DIR, "gallery_db.json")

# Create an empty gallery database on first run so later reads don't fail.
if not os.path.exists(GALLERY_DB):
    with open(GALLERY_DB, "w") as f:
        json.dump({"images": []}, f)

def clear_cache(transformer):
    """Empty the cached key/value bank on every attention processor.

    Called after each generation so conditioning state from one request
    does not leak into the next.

    Args:
        transformer: Model exposing an ``attn_processors`` mapping whose
            values each carry a clearable ``bank_kv`` container.
    """
    # Only the processor objects are needed; iterate .values() instead of
    # .items() with an unused key (ruff PERF102).
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

def add_to_gallery(image: Image.Image, prompt: str, control_type: str) -> str:
    """Persist a generated image and record it in the gallery database.

    Saves *image* as a uniquely named PNG under ``GALLERY_DIR``, appends a
    metadata record to the JSON index at ``GALLERY_DB``, and returns the
    path of the saved file.
    """
    new_id = str(uuid.uuid4())
    file_name = f"{new_id}.png"
    saved_path = os.path.join(GALLERY_DIR, file_name)

    image.save(saved_path)

    # Read-modify-write the gallery index.
    with open(GALLERY_DB, "r") as f:
        db = json.load(f)

    record = {
        "id": new_id,
        "filename": file_name,
        "prompt": prompt,
        "control_type": control_type,
        "created_at": datetime.now().isoformat(),
    }
    db["images"].append(record)

    with open(GALLERY_DB, "w") as f:
        json.dump(db, f, indent=2)

    return saved_path

def get_gallery_images() -> List[Dict]:
    """Return all image records from the gallery database.

    Returns:
        The list stored under the ``"images"`` key of ``GALLERY_DB``, or an
        empty list when the database is missing, unreadable, or malformed.
    """
    try:
        with open(GALLERY_DB, "r") as f:
            db = json.load(f)
        return db["images"]
    # Narrowed from a bare ``except:`` which also swallowed NameError,
    # KeyboardInterrupt, etc.; only DB-access failures mean "empty gallery".
    except (OSError, json.JSONDecodeError, KeyError):
        return []

def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, progress=gr.Progress()):
    """Generate one image with the selected control LoRA and save it to the gallery.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        spatial_img: Optional PIL image used as the spatial condition.
        height, width: Output resolution (coerced to int for the pipeline).
        seed: RNG seed for reproducible generation.
        control_type: Which control LoRA to apply (currently only "Ghibli").
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        The generated PIL image.

    Raises:
        gr.Error: If *control_type* has no corresponding LoRA.
    """
    # Resolve the LoRA weights for the requested control type. The original
    # code left ``lora_path`` unbound for unknown types, which surfaced later
    # as an UnboundLocalError; fail fast with a user-visible error instead.
    if control_type == "Ghibli":
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
    else:
        raise gr.Error(f"Unsupported control type: {control_type!r}")
    set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512, device="cpu")
    
    # The pipeline expects a list of spatial condition images (may be empty).
    spatial_imgs = [spatial_img] if spatial_img else []
    
    progress(0, desc="Starting generation...")
    image = pipe(
        prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=15,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(seed), 
        subject_images=[],
        spatial_images=spatial_imgs,
        cond_size=512,
    ).images[0]
    
    # Persist the result, then release the attention KV cache so state does
    # not leak into the next request. (Return value of add_to_gallery was
    # previously bound to an unused local.)
    add_to_gallery(image, prompt, control_type)
    clear_cache(pipe.transformer)
    
    return image

# Define the Gradio interface components
control_types = ["Ghibli"]

# Example data for the gr.Examples widget. NOTE: the images are opened
# eagerly at import time, so the app fails to start if any file under
# ./test_imgs is missing.
single_examples = [
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 512, 512, 5, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 512, 512, 42, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 512, 512, 1, "Ghibli"],
]

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
    gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
    gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Running on CPU due to free tier limitations; expect slower performance and lower resolution.)")
    
    gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
    gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")

    # Tab 1: prompt + condition image inputs on the left, output on the right.
    with gr.Tab("Ghibli Condition Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
                spatial_img = gr.Image(label="Ghibli Image", type="pil")
                height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512)
                width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512)
                seed = gr.Number(label="Seed", value=42)
                control_type = gr.Dropdown(choices=control_types, label="Control Type")
                single_generate_btn = gr.Button("Generate Image")
            with gr.Column():
                single_output_image = gr.Image(label="Generated Image")

        # Clickable examples; cache_examples=False so nothing is generated
        # at startup (generation on CPU is slow).
        gr.Examples(
            examples=single_examples,
            inputs=[prompt, spatial_img, height, width, seed, control_type],
            outputs=single_output_image,
            fn=single_condition_generate_image,
            cache_examples=False,
            label="Single Condition Examples"
        )

    # Tab 2: browse previously generated images stored in GALLERY_DIR.
    with gr.Tab("Gallery"):
        gallery = gr.Gallery(
            label="Generated Images",
            show_label=True,
            elem_id="gallery"
        )
        
        refresh_btn = gr.Button("Refresh Gallery")
        
        def load_gallery():
            """Return the file paths of all gallery images for gr.Gallery."""
            images = get_gallery_images()
            return [os.path.join(GALLERY_DIR, img["filename"]) for img in images]
        
        refresh_btn.click(
            fn=load_gallery,
            outputs=gallery
        )
        
        # Load gallery on page load
        demo.load(
            fn=load_gallery,
            outputs=gallery
        )

    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type],
        outputs=single_output_image,
        concurrency_limit=1  # Process one at a time
    )

# Launch the Gradio app
demo.queue().launch()