innoai commited on
Commit
76f3bac
·
verified ·
1 Parent(s): a4c5f25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -28
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  from PIL import Image
7
  from tqdm import tqdm
8
  import gradio as gr
 
9
 
10
  from safetensors.torch import save_file
11
  from src.pipeline import FluxPipeline
@@ -16,43 +17,91 @@ from src.lora_helper import set_single_lora, set_multi_lora, unset_lora
16
  base_path = "black-forest-labs/FLUX.1-dev"
17
  lora_base_path = "./models"
18
 
 
 
19
 
20
- pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.float32) # 使用 float32 而不是 bfloat16,以便在 CPU 上运行
21
- transformer = FluxTransformer2DModel.from_pretrained(base_path, subfolder="transformer", torch_dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
22
  pipe.transformer = transformer
23
- # 不再使用 cuda,在 CPU 环境运行
24
- # pipe.to("cuda")
25
 
26
  def clear_cache(transformer):
27
  for name, attn_processor in transformer.attn_processors.items():
28
  attn_processor.bank_kv.clear()
 
 
 
29
 
30
  # Define the Gradio interface
31
- @spaces.CPU() # 使用 CPU 装饰器替代 GPU 装饰器
32
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
33
- # Set the control type
34
- if control_type == "Ghibli":
35
- lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
36
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
37
-
38
- # Process the image
39
- spatial_imgs = [spatial_img] if spatial_img else []
40
-
41
- # 由于在 CPU 上运行,可能需要降低一些参数来提高性能
42
- image = pipe(
43
- prompt,
44
- height=int(height),
45
- width=int(width),
46
- guidance_scale=3.5,
47
- num_inference_steps=20, # 减少推理步骤以在 CPU 上更快运行
48
- max_sequence_length=512,
49
- generator=torch.Generator().manual_seed(seed), # 移除 "cpu" 参数,因为在 CPU 上默认就是 CPU 生成器
50
- subject_images=[],
51
- spatial_images=spatial_imgs,
52
- cond_size=512,
53
- ).images[0]
54
- clear_cache(pipe.transformer)
55
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # Define the Gradio interface components
58
  control_types = ["Ghibli"]
 
6
  from PIL import Image
7
  from tqdm import tqdm
8
  import gradio as gr
9
+ import gc # 导入垃圾回收模块
10
 
11
  from safetensors.torch import save_file
12
  from src.pipeline import FluxPipeline
 
17
  base_path = "black-forest-labs/FLUX.1-dev"
18
  lora_base_path = "./models"
19
 
20
+ # 设置更低的内存使用限制
21
+ torch.backends.cudnn.benchmark = False # 关闭 cudnn benchmark 以减少内存占用
22
 
23
+ # 使用较低精度和更保守的加载选项
24
+ pipe = FluxPipeline.from_pretrained(
25
+ base_path,
26
+ torch_dtype=torch.float32,
27
+ low_cpu_mem_usage=True, # 启用低内存使用模式
28
+ use_safetensors=True # 使用 safetensors 以减少内存使用
29
+ )
30
+ transformer = FluxTransformer2DModel.from_pretrained(
31
+ base_path,
32
+ subfolder="transformer",
33
+ torch_dtype=torch.float32,
34
+ low_cpu_mem_usage=True,
35
+ use_safetensors=True
36
+ )
37
  pipe.transformer = transformer
 
 
38
 
39
  def clear_cache(transformer):
40
  for name, attn_processor in transformer.attn_processors.items():
41
  attn_processor.bank_kv.clear()
42
+ # 手动触发垃圾回收
43
+ gc.collect()
44
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
45
 
46
  # Define the Gradio interface
47
+ @spaces.CPU()
48
  def single_condition_generate_image(prompt, spatial_img, height, width, seed, control_type):
49
+ try:
50
+ # 限制图像尺寸,减少内存使用
51
+ max_dimension = 512 # 设置最大尺寸限制
52
+ if int(height) > max_dimension or int(width) > max_dimension:
53
+ aspect_ratio = float(width) / float(height)
54
+ if aspect_ratio > 1:
55
+ width = max_dimension
56
+ height = int(max_dimension / aspect_ratio)
57
+ else:
58
+ height = max_dimension
59
+ width = int(max_dimension * aspect_ratio)
60
+
61
+ # Set the control type
62
+ if control_type == "Ghibli":
63
+ lora_path = os.path.join(lora_base_path, "Ghibli.safetensors")
64
+ set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
65
+
66
+ # 如果有空间图像,确保其尺寸合理
67
+ if spatial_img:
68
+ # 调整空间图像尺寸以减少内存使用
69
+ max_img_size = 1024
70
+ if max(spatial_img.size) > max_img_size:
71
+ ratio = max_img_size / max(spatial_img.size)
72
+ new_size = (int(spatial_img.size[0] * ratio), int(spatial_img.size[1] * ratio))
73
+ spatial_img = spatial_img.resize(new_size, Image.LANCZOS)
74
+ spatial_imgs = [spatial_img]
75
+ else:
76
+ spatial_imgs = []
77
+
78
+ # 使用更保守的参数
79
+ image = pipe(
80
+ prompt,
81
+ height=int(height),
82
+ width=int(width),
83
+ guidance_scale=3.0, # 略微降低指导比例
84
+ num_inference_steps=15, # 进一步减少推理步骤
85
+ max_sequence_length=384, # 减少序列长度
86
+ generator=torch.Generator().manual_seed(seed),
87
+ subject_images=[],
88
+ spatial_images=spatial_imgs,
89
+ cond_size=384, # 减小条件尺寸
90
+ ).images[0]
91
+
92
+ # 清理缓存并回收内存
93
+ clear_cache(pipe.transformer)
94
+ gc.collect()
95
+
96
+ return image
97
+ except Exception as e:
98
+ # 处理错误并清理内存
99
+ clear_cache(pipe.transformer)
100
+ gc.collect()
101
+ print(f"Error during image generation: {str(e)}")
102
+ # 返回一个错误图像或消息
103
+ error_img = Image.new('RGB', (400, 200), color=(255, 255, 255))
104
+ return error_img
105
 
106
  # Define the Gradio interface components
107
  control_types = ["Ghibli"]