import spaces
import os
import json
import time
import torch
from PIL import Image
from tqdm import tqdm
import gradio as gr
import uuid
from datetime import datetime
from typing import List, Dict, Optional
from safetensors.torch import save_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
# Initialize the FLUX pipeline and the EasyControl transformer
base_path = "black-forest-labs/FLUX.1-dev"
lora_base_path = "./models"
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe.transformer = transformer
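
# Note: the pipeline stays on CPU here (no .to("cuda")), matching the free-tier
# Space described in the UI text below. On GPU hardware one would typically move
# the pipeline to CUDA and, on a ZeroGPU Space, decorate the generation function
# with the `spaces` GPU decorator -- a sketch of an assumption, not part of this demo:
#   pipe.to("cuda")
#   @spaces.GPU
#   def single_condition_generate_image(...): ...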
# Gallery storage
GALLERY_DIR = "gallery"
os.makedirs(GALLERY_DIR, exist_ok=True)
GALLERY_DB = os.path.join(GALLERY_DIR, "gallery_db.json")
# Initialize gallery database
if not os.path.exists(GALLERY_DB):
    with open(GALLERY_DB, "w") as f:
        json.dump({"images": []}, f)
def clear_cache(transformer):
    """Reset the key/value banks cached by the transformer's attention processors."""
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()
def add_to_gallery(image: Image.Image, prompt: str, control_type: str) -> str:
    """Save image to gallery and return its path"""
    image_id = str(uuid.uuid4())
    filename = f"{image_id}.png"
    filepath = os.path.join(GALLERY_DIR, filename)
    image.save(filepath)

    # Update gallery database
    with open(GALLERY_DB, "r") as f:
        db = json.load(f)

    db["images"].append({
        "id": image_id,
        "filename": filename,
        "prompt": prompt,
        "control_type": control_type,
        "created_at": datetime.now().isoformat()
    })

    with open(GALLERY_DB, "w") as f:
        json.dump(db, f, indent=2)

    return filepath
def get_gallery_images() -> List[Dict]:
    """Get all gallery images from the database"""
    try:
        with open(GALLERY_DB, "r") as f:
            db = json.load(f)
        return db["images"]
    except (FileNotFoundError, json.JSONDecodeError):
        return []
def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type, progress=gr.Progress()):
    # Load the LoRA weights for the selected control type
    if control_type == "Ghibli":
        lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
        set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512, device="cpu")

    # Process the spatial conditioning image
    spatial_imgs = [spatial_img] if spatial_img else []
    progress(0, desc="Starting generation...")
    image = pipe(
        prompt,
        height=int(height),
        width=int(width),
        guidance_scale=3.5,
        num_inference_steps=15,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        subject_images=[],
        spatial_images=spatial_imgs,
        cond_size=512,
    ).images[0]

    # Save to gallery and clear the attention caches for the next request
    add_to_gallery(image, prompt, control_type)
    clear_cache(pipe.transformer)
    return image
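
# Illustrative only (an assumption, not exercised by the app): the generator can
# be invoked directly with the same arguments the examples below use, e.g.:
#   img = single_condition_generate_image(
#       "Ghibli Studio style, Charming hand-drawn anime-style illustration",
#       Image.open("./test_imgs/00.png"), 512, 512, 42, "Ghibli")
#   img.save("output.png")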
# Define the Gradio interface components
control_types = ["Ghibli"]
# Example data
single_examples = [
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 512, 512, 5, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 512, 512, 42, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 512, 512, 1, "Ghibli"],
    ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 512, 512, 1, "Ghibli"],
]
# Create the Gradio Blocks interface
with gr.Blocks() as demo:
gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Running on CPU due to free tier limitations; expect slower performance and lower resolution.)")
gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
with gr.Tab("Ghibli Condition Generation"):
with gr.Row():
with gr.Column():
prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
spatial_img = gr.Image(label="Ghibli Image", type="pil")
height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512)
width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512)
seed = gr.Number(label="Seed", value=42)
control_type = gr.Dropdown(choices=control_types, label="Control Type")
single_generate_btn = gr.Button("Generate Image")
with gr.Column():
single_output_image = gr.Image(label="Generated Image")
gr.Examples(
examples=single_examples,
inputs=[prompt, spatial_img, height, width, seed, control_type],
outputs=single_output_image,
fn=single_condition_generate_image,
cache_examples=False,
label="Single Condition Examples"
)
with gr.Tab("Gallery"):
gallery = gr.Gallery(
label="Generated Images",
show_label=True,
elem_id="gallery"
)
refresh_btn = gr.Button("Refresh Gallery")
def load_gallery():
images = get_gallery_images()
return [os.path.join(GALLERY_DIR, img["filename"]) for img in images]
refresh_btn.click(
fn=load_gallery,
outputs=gallery
)
# Load gallery on page load
demo.load(
fn=load_gallery,
outputs=gallery
)
    single_generate_btn.click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type],
        outputs=single_output_image,
        concurrency_limit=1  # Process one request at a time
    )
# Launch the Gradio app
demo.queue().launch()