ar08 commited on
Commit
6ba84d9
·
verified ·
1 Parent(s): e504971

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -69
app.py CHANGED
@@ -1,91 +1,106 @@
 
1
  import os
 
2
  import time
3
- from datetime import datetime
4
- import uuid
5
- import gradio as gr
6
  from PIL import Image
 
 
7
 
8
- # Define folders
9
- gallery_folder = "./gallery"
10
- task_folder = "./tasks"
11
-
12
- os.makedirs(gallery_folder, exist_ok=True)
13
- os.makedirs(task_folder, exist_ok=True)
14
-
15
- # Function to save image and metadata
16
- def save_image_to_gallery(image, prompt, seed):
17
- timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
18
- filename = f"{timestamp}_seed{seed}.png"
19
- metadata_filename = filename.replace(".png", ".txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- image_path = os.path.join(gallery_folder, filename)
22
- metadata_path = os.path.join(gallery_folder, metadata_filename)
23
-
24
- image.save(image_path)
25
- with open(metadata_path, "w") as f:
26
- f.write(f"Prompt: {prompt}\nSeed: {seed}\nTime: {timestamp}")
27
-
28
- # Modified generation function with saving
29
- def single_condition_generate_image_with_save(prompt, spatial_img, height, width, seed, control_type):
30
- image = single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type)
31
- save_image_to_gallery(image, prompt, seed)
 
 
 
 
32
  return image
33
 
34
- # Get gallery images
35
- def load_gallery():
36
- images = []
37
- for fname in sorted(os.listdir(gallery_folder), reverse=True):
38
- if fname.endswith(".png"):
39
- images.append(os.path.join(gallery_folder, fname))
40
- return images
41
-
42
- # Save user task
43
- def submit_task(prompt, image):
44
- task_id = str(uuid.uuid4())
45
- task_img_path = os.path.join(task_folder, f"{task_id}.png")
46
- task_txt_path = os.path.join(task_folder, f"{task_id}.txt")
47
- image.save(task_img_path)
48
- with open(task_txt_path, "w") as f:
49
- f.write(prompt)
50
- return f"✅ Task submitted successfully with ID: {task_id}"
51
-
52
- # ---- GUI ----
53
  with gr.Blocks() as demo:
54
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
 
 
 
 
 
55
 
56
  with gr.Tab("Ghibli Condition Generation"):
57
  with gr.Row():
58
  with gr.Column():
59
  prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
60
  spatial_img = gr.Image(label="Ghibli Image", type="pil")
61
- height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512)
62
- width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512)
63
  seed = gr.Number(label="Seed", value=42)
64
- control_type = gr.Dropdown(choices=["Ghibli"], label="Control Type")
65
- generate_btn = gr.Button("Generate Image")
66
  with gr.Column():
67
- output_img = gr.Image(label="Generated Image")
68
 
69
- generate_btn.click(
70
- single_condition_generate_image_with_save,
71
  inputs=[prompt, spatial_img, height, width, seed, control_type],
72
- outputs=output_img
 
 
 
73
  )
74
 
75
- with gr.Tab("Gallery"):
76
- gallery_gallery = gr.Gallery(label="Previous Generations", show_label=False, elem_id="gallery", columns=[3], rows=[3])
77
-
78
-
79
- gallery_btn = gr.Button("🔄 Refresh Gallery")
80
- gallery_btn.click(fn=load_gallery, outputs=gallery_gallery)
81
-
82
- with gr.Tab("Tasks"):
83
- with gr.Row():
84
- task_prompt = gr.Textbox(label="Your Prompt")
85
- task_image = gr.Image(label="Your Image", type="pil")
86
- task_submit_btn = gr.Button("Submit Task")
87
- task_submit_output = gr.Textbox(label="Submission Result")
88
-
89
- task_submit_btn.click(fn=submit_task, inputs=[task_prompt, task_image], outputs=task_submit_output)
90
 
91
- demo.queue().launch()
 
 
1
+ import spaces
2
  import os
3
+ import json
4
  import time
5
+ import torch
 
 
6
  from PIL import Image
7
+ from tqdm import tqdm
8
+ import gradio as gr
9
 
10
+ from safetensors.torch import save_file
11
+ from src.pipeline import FluxPipeline
12
+ from src.transformer_flux import FluxTransformer2DModel
13
+ from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
+
15
+ # Initialize the image processor
16
+ base_path = "black-forest-labs/FLUX.1-dev"
17
+ lora_base_path = "./models"
18
+
19
+ pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16)
20
+ transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.bfloat16)
21
+ pipe.transformer = transformer
22
+ # 移除 pipe.to("cuda"),默认使用CPU
23
+
24
+ def clear_cache(transformer):
25
+ for name, attn_processor in transformer.attn_processors.items():
26
+ attn_processor.bank_kv.clear()
27
+
28
+ # Define the Gradio interface
29
+ @spaces.GPU() # 改为 @spaces.CPU() 或直接移除,因为免费层没有GPU
30
+ def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
31
+ # Set the control type
32
+ if control_type == "Ghibli":
33
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
34
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512, device="cpu")
35
 
36
+ # Process the image
37
+ spatial_imgs = [spatial_img] if spatial_img else []
38
+ image = pipe(
39
+ prompt,
40
+ height=int(height),
41
+ width=int(width),
42
+ guidance_scale=3.5,
43
+ num_inference_steps=15, # 减少步数以适应CPU
44
+ max_sequence_length=512,
45
+ generator=torch.Generator("cpu").manual_seed(seed),
46
+ subject_images=[],
47
+ spatial_images=spatial_imgs,
48
+ cond_size=512,
49
+ ).images[0]
50
+ clear_cache(pipe.transformer)
51
  return image
52
 
53
+ # Define the Gradio interface components
54
+ control_types = ["Ghibli"]
55
+
56
+ # Example data (调整分辨率以适应CPU)
57
+ single_examples = [
58
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/00.png"), 512, 512, 5, "Ghibli"],
59
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/02.png"), 512, 512, 42, "Ghibli"],
60
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/03.png"), 512, 512, 1, "Ghibli"],
61
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/04.png"), 512, 512, 1, "Ghibli"],
62
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/06.png"), 512, 512, 1, "Ghibli"],
63
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/07.png"), 512, 512, 1, "Ghibli"],
64
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/08.png"), 512, 512, 1, "Ghibli"],
65
+ ["Ghibli Studio style, Charming hand-drawn anime-style illustration", Image.open("./test_imgs/09.png"), 512, 512, 1, "Ghibli"],
66
+ ]
67
+
68
+ # Create the Gradio Blocks interface
 
 
 
69
  with gr.Blocks() as demo:
70
  gr.Markdown("# Ghibli Studio Control Image Generation with EasyControl")
71
+ gr.Markdown("The model is trained on **only 100 real Asian faces** paired with **GPT-4o-generated Ghibli-style counterparts**, and it preserves facial features while applying the iconic anime aesthetic.")
72
+ gr.Markdown("Generate images using EasyControl with Ghibli control LoRAs.(Running on CPU due to free tier limitations; expect slower performance and lower resolution.)")
73
+
74
+ gr.Markdown("**[Attention!!]**:The recommended prompts for using Ghibli Control LoRA should include the trigger words: `Ghibli Studio style, Charming hand-drawn anime-style illustration`")
75
+ gr.Markdown("😊😊If you like this demo, please give us a star (github: [EasyControl](https://github.com/Xiaojiu-z/EasyControl))")
76
 
77
  with gr.Tab("Ghibli Condition Generation"):
78
  with gr.Row():
79
  with gr.Column():
80
  prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, Charming hand-drawn anime-style illustration")
81
  spatial_img = gr.Image(label="Ghibli Image", type="pil")
82
+ height = gr.Slider(minimum=256, maximum=512, step=64, label="Height", value=512) # 限制最大分辨率
83
+ width = gr.Slider(minimum=256, maximum=512, step=64, label="Width", value=512) # 限制最大分辨率
84
  seed = gr.Number(label="Seed", value=42)
85
+ control_type = gr.Dropdown(choices=control_types, label="Control Type")
86
+ single_generate_btn = gr.Button("Generate Image")
87
  with gr.Column():
88
+ single_output_image = gr.Image(label="Generated Image")
89
 
90
+ gr.Examples(
91
+ examples=single_examples,
92
  inputs=[prompt, spatial_img, height, width, seed, control_type],
93
+ outputs=single_output_image,
94
+ fn=single_condition_generate_image,
95
+ cache_examples=False,
96
+ label="Single Condition Examples"
97
  )
98
 
99
+ single_generate_btn.click(
100
+ single_condition_generate_image,
101
+ inputs=[prompt, spatial_img, height, width, seed, control_type],
102
+ outputs=single_output_image
103
+ )
 
 
 
 
 
 
 
 
 
 
104
 
105
+ # Launch the Gradio app
106
+ demo.queue().launch()