zhiweili commited on
Commit
b38f27e
1 Parent(s): 0ff5d61

add upscale

Browse files
Files changed (2) hide show
  1. app_base.py +31 -10
  2. upscale.py +25 -0
app_base.py CHANGED
@@ -8,7 +8,7 @@ from segment_utils import(
8
  segment_image,
9
  restore_result,
10
  )
11
- from enhance_utils import enhance_image
12
 
13
  DEFAULT_SRC_PROMPT = "a person"
14
  DEFAULT_EDIT_PROMPT = "a person with perfect face"
@@ -31,15 +31,24 @@ def create_demo() -> gr.Blocks:
31
  start_step: int,
32
  guidance_scale: float,
33
  generate_size: int,
34
- pre_enhance: bool = True,
35
- pre_enhance_scale: int = 2,
 
 
 
 
36
  ):
37
  w2 = 1.0
38
  run_task_time = 0
39
  time_cost_str = ''
40
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
41
- if pre_enhance:
42
- input_image = enhance_image(input_image, enhance_face=True, scale=pre_enhance_scale)
 
 
 
 
 
43
  input_image = input_image.resize((generate_size, generate_size))
44
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
45
  run_model = base_run
@@ -56,7 +65,12 @@ def create_demo() -> gr.Blocks:
56
  guidance_scale,
57
  )
58
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
59
- enhanced_image = enhance_image(res_image)
 
 
 
 
 
60
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
61
 
62
  return enhanced_image, res_image, time_cost_str
@@ -79,16 +93,21 @@ def create_demo() -> gr.Blocks:
79
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
80
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
81
  category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
 
 
 
 
 
 
 
82
  with gr.Column():
83
  num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
84
  start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
85
  with gr.Accordion("Advanced Options", open=False):
86
  guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
87
- generate_size = gr.Number(label="Generate Size", value=512)
88
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
89
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
90
- pre_enhance = gr.Checkbox(label="Pre Enhance", value=True)
91
- pre_enhance_scale = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Pre Enhance Scale")
92
  with gr.Column():
93
  seed = gr.Number(label="Seed", value=8)
94
  w1 = gr.Number(label="W1", value=1.5)
@@ -112,7 +131,9 @@ def create_demo() -> gr.Blocks:
112
  outputs=[origin_area_image, croper],
113
  ).success(
114
  fn=image_to_image,
115
- inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step, guidance_scale, generate_size, pre_enhance, pre_enhance_scale],
 
 
116
  outputs=[enhanced_image, generated_image, generated_cost],
117
  ).success(
118
  fn=restore_result,
 
8
  segment_image,
9
  restore_result,
10
  )
11
+ from upscale import upscale_image
12
 
13
  DEFAULT_SRC_PROMPT = "a person"
14
  DEFAULT_EDIT_PROMPT = "a person with perfect face"
 
31
  start_step: int,
32
  guidance_scale: float,
33
  generate_size: int,
34
+ upscale_prompt: str,
35
+ upscale_start_size: int = 256,
36
+ upscale_steps: int = 10,
37
+ pre_upscale: bool = True,
38
+ pre_upscale_start_size: int = 128,
39
+ pre_upscale_steps: int = 30,
40
  ):
41
  w2 = 1.0
42
  run_task_time = 0
43
  time_cost_str = ''
44
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
45
+ if pre_upscale:
46
+ input_image = upscale_image(
47
+ input_image,
48
+ upscale_prompt,
49
+ start_size=pre_upscale_start_size,
50
+ upscale_steps=pre_upscale_steps,
51
+ )
52
  input_image = input_image.resize((generate_size, generate_size))
53
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
54
  run_model = base_run
 
65
  guidance_scale,
66
  )
67
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
68
+ enhanced_image = upscale_image(
69
+ res_image,
70
+ upscale_prompt,
71
+ start_size=upscale_start_size,
72
+ upscale_steps=upscale_steps,
73
+ )
74
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
75
 
76
  return enhanced_image, res_image, time_cost_str
 
93
  input_image_prompt = gr.Textbox(lines=1, label="Input Image Prompt", value=DEFAULT_SRC_PROMPT)
94
  edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
95
  category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
96
+ with gr.Accordion("Advanced Options", open=False):
97
+ upscale_prompt = gr.Textbox(lines=1, label="Upscale Prompt", value="a person with pefect face")
98
+ upscale_start_size = gr.Number(label="Upscale Start Size", value=256)
99
+ upscale_steps = gr.Number(label="Upscale Steps", value=10)
100
+ pre_upscale = gr.Checkbox(label="Pre Upscale", value=True)
101
+ pre_upscale_start_size = gr.Number(label="Pre Upscale Start Size", value=128)
102
+ pre_upscale_steps = gr.Number(label="Pre Upscale Steps", value=30)
103
  with gr.Column():
104
  num_steps = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Num Steps")
105
  start_step = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Start Step")
106
  with gr.Accordion("Advanced Options", open=False):
107
  guidance_scale = gr.Slider(minimum=0, maximum=20, value=0, step=0.5, label="Guidance Scale")
108
+ generate_size = gr.Number(label="Generate Size", value=256)
109
  mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
110
  mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
 
 
111
  with gr.Column():
112
  seed = gr.Number(label="Seed", value=8)
113
  w1 = gr.Number(label="W1", value=1.5)
 
131
  outputs=[origin_area_image, croper],
132
  ).success(
133
  fn=image_to_image,
134
+ inputs=[origin_area_image, input_image_prompt, edit_prompt,seed,w1, num_steps, start_step,
135
+ guidance_scale, generate_size, upscale_prompt, upscale_start_size, upscale_steps,
136
+ pre_upscale, pre_upscale_start_size, pre_upscale_steps],
137
  outputs=[enhanced_image, generated_image, generated_cost],
138
  ).success(
139
  fn=restore_result,
upscale.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from PIL import Image
4
+ from diffusers import StableDiffusionUpscalePipeline
5
+
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model_id = "stabilityai/stable-diffusion-x4-upscaler"
9
+ upscale_pipe = StableDiffusionUpscalePipeline.from_pretrained(model_id, torch_dtype=torch.float16)
10
+ upscale_pipe = upscale_pipe.to(device)
11
+
12
+ def upscale_image(
13
+ input_image: Image,
14
+ prompt: str,
15
+ start_size: int = 128,
16
+ upscale_steps: int = 30,
17
+ ):
18
+ input_image = input_image.resize((start_size, start_size))
19
+ upscaled_image = upscale_pipe(
20
+ prompt=prompt,
21
+ image=input_image,
22
+ num_inference_steps=upscale_steps,
23
+ ).images[0]
24
+
25
+ return upscaled_image