ar08 commited on
Commit
d52abc5
·
verified ·
1 Parent(s): 9522bfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -79
app.py CHANGED
@@ -1,106 +1,90 @@
1
- import spaces
2
  import os
3
- import json
4
  import time
5
- import torch
6
- from PIL import Image
7
- from tqdm import tqdm
8
  import gradio as gr
 
9
 
10
- from safetensors.torch import save_file
11
- from src.pipeline import FluxPipeline
12
- from src.transformer_flux import FluxTransformer2DModel
13
- from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
 
15
- # Initialize the image processor
16
- base_path = "black-forest-labs/FLUX.1-dev"
17
- lora_base_path = "./models"
18
 
19
- pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
20
- transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
21
- pipe.transformer = transformer
22
- # 移除 pipe.to("cuda"),默认使用CPU
 
 
 
 
23
 
24
- def clear_cache(transformer):
25
- for name, attn_processor in transformer.attn_processors.items():
26
- attn_processor.bank_kv.clear()
27
 
28
- # Define the Gradio interface
29
- @spaces.GPU() # 改为 @spaces.CPU() 或直接移除,因为免费层没有GPU
30
- def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
31
- # Set the control type
32
- if control_type == "Ghibli":
33
- lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
34
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512, device="cpu")
35
-
36
- # Process the image
37
- spatial_imgs = [spatial_img] if spatial_img else []
38
- image = pipe(
39
- prompt,
40
- height=int(height),
41
- width=int(width),
42
- guidance_scale=3.5,
43
- num_inference_steps=15, # 减少步数以适应CPU
44
- max_sequence_length=512,
45
- generator=torch.Generator("cpu").manual_seed(seed),
46
- subject_images=[],
47
- spatial_images=spatial_imgs,
48
- cond_size=512,
49
- ).images[0]
50
- clear_cache(pipe.transformer)
51
  return image
52
 
53
- # Define the Gradio interface components
54
- control_types = ["Ghibli"]
 
 
 
 
 
55
 
56
- # Example data (调整分辨率以适应CPU)
57
- single_examples = [
58
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 512, 512, 5, "Ghibli"],
59
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 512, 512, 42, "Ghibli"],
60
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 512, 512, 1, "Ghibli"],
61
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 512, 512, 1, "Ghibli"],
62
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 512, 512, 1, "Ghibli"],
63
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 512, 512, 1, "Ghibli"],
64
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 512, 512, 1, "Ghibli"],
65
- ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 512, 512, 1, "Ghibli"],
66
- ]
67
 
68
- # Create the Gradio Blocks interface
69
  with gr.Blocks() as demo:
70
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
71
- gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
72
- gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Running on CPU due to free tier limitations; expect slower performance and lower resolution.)")
73
-
74
- gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
75
- gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
76
 
77
  with gr.Tab("Ghibli Condition Generation"):
78
  with gr.Row():
79
  with gr.Column():
80
  prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
81
  spatial_img = gr.Image(label="Ghibli Image", type="pil")
82
- height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512) # 限制最大分辨率
83
- width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512) # 限制最大分辨率
84
  seed = gr.Number(label="Seed", value=42)
85
- control_type = gr.Dropdown(choices=control_types, label="Control Type")
86
- single_generate_btn = gr.Button("Generate Image")
87
  with gr.Column():
88
- single_output_image = gr.Image(label="Generated Image")
89
 
90
- gr.Examples(
91
- examples=single_examples,
92
  inputs=[prompt, spatial_img, height, width, seed, control_type],
93
- outputs=single_output_image,
94
- fn=single_condition_generate_image,
95
- cache_examples=False,
96
- label="Single Condition Examples"
97
  )
98
 
99
- single_generate_btn.click(
100
- single_condition_generate_image,
101
- inputs=[prompt, spatial_img, height, width, seed, control_type],
102
- outputs=single_output_image
103
- )
 
 
 
 
 
 
 
 
 
104
 
105
- # Launch the Gradio app
106
- demo.queue().launch()
 
 
1
  import os
 
2
  import time
3
+ from datetime import datetime
4
+ import uuid
 
5
  import gradio as gr
6
+ from PIL import Image
7
 
8
+ # Define folders
9
+ gallery_folder = "./gallery"
10
+ task_folder = "./tasks"
 
11
 
12
+ os.makedirs(gallery_folder, exist_ok=True)
13
+ os.makedirs(task_folder, exist_ok=True)
 
14
 
15
+ # Function to save image and metadata
16
+ def save_image_to_gallery(image, prompt, seed):
17
+ timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
18
+ filename = f"{timestamp}_seed{seed}.png"
19
+ metadata_filename = filename.replace(".png", ".txt")
20
+
21
+ image_path = os.path.join(gallery_folder, filename)
22
+ metadata_path = os.path.join(gallery_folder, metadata_filename)
23
 
24
+ image.save(image_path)
25
+ with open(metadata_path, "w") as f:
26
+ f.write(f"Prompt: {prompt}\nSeed: {seed}\nTime: {timestamp}")
27
 
28
+ # Modified generation function with saving
29
+ def single_condition_generate_image_with_save(prompt, spatial_img, height, width, seed, control_type):
30
+ image = single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type)
31
+ save_image_to_gallery(image, prompt, seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  return image
33
 
34
+ # Get gallery images
35
+ def load_gallery():
36
+ images = []
37
+ for fname in sorted(os.listdir(gallery_folder), reverse=True):
38
+ if fname.endswith(".png"):
39
+ images.append(os.path.join(gallery_folder, fname))
40
+ return images
41
 
42
+ # Save user task
43
+ def submit_task(prompt, image):
44
+ task_id = str(uuid.uuid4())
45
+ task_img_path = os.path.join(task_folder, f"{task_id}.png")
46
+ task_txt_path = os.path.join(task_folder, f"{task_id}.txt")
47
+ image.save(task_img_path)
48
+ with open(task_txt_path, "w") as f:
49
+ f.write(prompt)
50
+ return f" Task submitted successfully with ID: {task_id}"
 
 
51
 
52
+ # ---- GUI ----
53
  with gr.Blocks() as demo:
54
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
 
 
 
 
 
55
 
56
  with gr.Tab("Ghibli Condition Generation"):
57
  with gr.Row():
58
  with gr.Column():
59
  prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
60
  spatial_img = gr.Image(label="Ghibli Image", type="pil")
61
+ height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512)
62
+ width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512)
63
  seed = gr.Number(label="Seed", value=42)
64
+ control_type = gr.Dropdown(choices=["Ghibli"], label="Control Type")
65
+ generate_btn = gr.Button("Generate Image")
66
  with gr.Column():
67
+ output_img = gr.Image(label="Generated Image")
68
 
69
+ generate_btn.click(
70
+ single_condition_generate_image_with_save,
71
  inputs=[prompt, spatial_img, height, width, seed, control_type],
72
+ outputs=output_img
 
 
 
73
  )
74
 
75
+ with gr.Tab("Gallery"):
76
+ gallery_gallery = gr.Gallery(label="Previous Generations", show_label=False, elem_id="gallery").style(grid=3)
77
+
78
+ gallery_btn = gr.Button("🔄 Refresh Gallery")
79
+ gallery_btn.click(fn=load_gallery, outputs=gallery_gallery)
80
+
81
+ with gr.Tab("Tasks"):
82
+ with gr.Row():
83
+ task_prompt = gr.Textbox(label="Your Prompt")
84
+ task_image = gr.Image(label="Your Image", type="pil")
85
+ task_submit_btn = gr.Button("Submit Task")
86
+ task_submit_output = gr.Textbox(label="Submission Result")
87
+
88
+ task_submit_btn.click(fn=submit_task, inputs=[task_prompt, task_image], outputs=task_submit_output)
89
 
90
+ demo.queue().launch()