Upload app.py

app.py CHANGED
@@ -6,6 +6,7 @@ import torch
 from PIL import Image
 from tqdm import tqdm
 import gradio as gr
+import gc  # import the garbage-collection module

 from safetensors.torch import save_file
 from src.pipeline import FluxPipeline
@@ -16,43 +17,91 @@ from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
 base_path = "black-forest-labs/FLUX.1-dev"
 lora_base_path = "./models"

-# (3 removed lines: earlier pipeline/transformer setup; their content is not visible in this view)
+# set a lower memory-usage limit
+torch.backends.cudnn.benchmark = False  # disable the cudnn benchmark to reduce memory usage
+
+# use lower precision and more conservative loading options
+pipe = FluxPipeline.from_pretrained(
+    base_path,
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=True,  # enable low-memory-usage mode
+    use_safetensors=True  # use safetensors to reduce memory usage
+)
+transformer = FluxTransformer2DModel.from_pretrained(
+    base_path,
+    subfolder="transformer",
+    torch_dtype=torch.float32,
+    low_cpu_mem_usage=True,
+    use_safetensors=True
+)
 pipe.transformer = transformer
-# no longer using cuda; running in a CPU environment
-# pipe.to("cuda")

 def clear_cache(transformer):
     for name, attn_processor in transformer.attn_processors.items():
         attn_processor.bank_kv.clear()
+    # manually trigger garbage collection
+    gc.collect()
+    torch.cuda.empty_cache() if torch.cuda.is_available() else None

 # Define the Gradio interface
-@spaces.CPU()
+@spaces.CPU()
 def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
-# (23 removed lines: previous body of this function; its content is not visible in this view)
+    try:
+        # limit the image dimensions to reduce memory usage
+        max_dimension = 512  # maximum dimension limit
+        if int(height) > max_dimension or int(width) > max_dimension:
+            aspect_ratio = float(width) / float(height)
+            if aspect_ratio > 1:
+                width = max_dimension
+                height = int(max_dimension / aspect_ratio)
+            else:
+                height = max_dimension
+                width = int(max_dimension * aspect_ratio)
+
+        # Set the control type
+        if control_type == "Ghibli":
+            lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
+            set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
+
+        # if a spatial image is provided, make sure its size is reasonable
+        if spatial_img:
+            # resize the spatial image to reduce memory usage
+            max_img_size = 1024
+            if max(spatial_img.size) > max_img_size:
+                ratio = max_img_size / max(spatial_img.size)
+                new_size = (int(spatial_img.size[0] * ratio), int(spatial_img.size[1] * ratio))
+                spatial_img = spatial_img.resize(new_size, Image.LANCZOS)
+            spatial_imgs = [spatial_img]
+        else:
+            spatial_imgs = []
+
+        # use more conservative parameters
+        image = pipe(
+            prompt,
+            height=int(height),
+            width=int(width),
+            guidance_scale=3.0,  # slightly lower guidance scale
+            num_inference_steps=15,  # further reduce the number of inference steps
+            max_sequence_length=384,  # reduce the sequence length
+            generator=torch.Generator().manual_seed(seed),
+            subject_images=[],
+            spatial_images=spatial_imgs,
+            cond_size=384,  # reduce the condition size
+        ).images[0]
+
+        # clear the cache and reclaim memory
+        clear_cache(pipe.transformer)
+        gc.collect()
+
+        return image
+    except Exception as e:
+        # handle the error and clean up memory
+        clear_cache(pipe.transformer)
+        gc.collect()
+        print(f"Error during image generation: {str(e)}")
+        # return an error image or message
+        error_img = Image.new('RGB', (400, 200), color=(255, 255, 255))
+        return error_img

 # Define the Gradio interface components
 control_types = ["Ghibli"]
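
The diff stops at the start of the UI section, so the Gradio wiring for single_condition_generate_image is not shown. As a rough sketch only, assuming a simple Blocks layout and the component names below (none of which appear in the diff), the function could be hooked up roughly like this:

# Hypothetical Gradio wiring for the generation function defined above.
# Component names, value ranges, and layout are assumptions; the actual
# interface code in app.py is unchanged by this commit and not shown here.
import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    spatial_img = gr.Image(label="Spatial image", type="pil")
    height = gr.Slider(256, 1024, value=512, step=64, label="Height")
    width = gr.Slider(256, 1024, value=512, step=64, label="Width")
    seed = gr.Number(value=42, precision=0, label="Seed")
    control_type = gr.Dropdown(choices=control_types, value="Ghibli", label="Control type")
    output = gr.Image(label="Generated image")

    gr.Button("Generate").click(
        single_condition_generate_image,
        inputs=[prompt, spatial_img, height, width, seed, control_type],
        outputs=output,
    )

demo.launch()

On a CPU-only Space, each generation call will still be slow even with the reduced resolution and 15 inference steps set in the diff, so enabling queuing with demo.queue() before launch is usually worthwhile.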