zhiweili commited on
Commit
91bb531
·
1 Parent(s): 304cdbb

test refiner

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. app_haircolor_refiner.py +124 -0
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from app_haircolor_inpaint_15 import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
+ from app_haircolor_refiner import create_demo as create_demo_haircolor
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
app_haircolor_refiner.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+ import numpy as np
6
+
7
+ from PIL import Image
8
+ from segment_utils import(
9
+ segment_image,
10
+ restore_result,
11
+ )
12
+ from diffusers import (
13
+ DiffusionPipeline,
14
+ )
15
+
16
+
17
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-refiner-1.0"
18
+
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ DEFAULT_EDIT_PROMPT = "blue hair"
22
+
23
+ DEFAULT_CATEGORY = "hair"
24
+
25
+ basepipeline = DiffusionPipeline.from_pretrained(
26
+ BASE_MODEL,
27
+ torch_dtype=torch.float16,
28
+ use_safetensors=True,
29
+ variant="fp16",
30
+ )
31
+
32
+ basepipeline = basepipeline.to(DEVICE)
33
+
34
+ @spaces.GPU(duration=30)
35
+ def image_to_image(
36
+ input_image: Image,
37
+ edit_prompt: str,
38
+ seed: int,
39
+ num_steps: int,
40
+ guidance_scale: float,
41
+ generate_size: int,
42
+ ):
43
+ run_task_time = 0
44
+ time_cost_str = ''
45
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
46
+
47
+ generator = torch.Generator(device=DEVICE).manual_seed(seed)
48
+ generated_image = basepipeline(
49
+ generator=generator,
50
+ prompt=edit_prompt,
51
+ image=input_image,
52
+ # denoising_start=denoising_start,
53
+ guidance_scale=guidance_scale,
54
+ num_inference_steps=num_steps,
55
+ ).images[0]
56
+
57
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
58
+
59
+ return generated_image, time_cost_str
60
+
61
+ def make_inpaint_condition(image, image_mask):
62
+ image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
63
+ image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
64
+
65
+ assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
66
+ image[image_mask > 0.5] = -1.0 # set as masked pixel
67
+ image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
68
+ image = torch.from_numpy(image)
69
+ return image
70
+
71
+ def get_time_cost(run_task_time, time_cost_str):
72
+ now_time = int(time.time()*1000)
73
+ if run_task_time == 0:
74
+ time_cost_str = 'start'
75
+ else:
76
+ if time_cost_str != '':
77
+ time_cost_str += f'-->'
78
+ time_cost_str += f'{now_time - run_task_time}'
79
+ run_task_time = now_time
80
+ return run_task_time, time_cost_str
81
+
82
+ def create_demo() -> gr.Blocks:
83
+ with gr.Blocks() as demo:
84
+ croper = gr.State()
85
+ with gr.Row():
86
+ with gr.Column():
87
+ edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
88
+ generate_size = gr.Number(label="Generate Size", value=512)
89
+ with gr.Column():
90
+ num_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Steps")
91
+ guidance_scale = gr.Slider(minimum=0, maximum=30, value=5, step=0.5, label="Guidance Scale")
92
+ with gr.Column():
93
+ with gr.Accordion("Advanced Options", open=False):
94
+ mask_expansion = gr.Number(label="Mask Expansion", value=50, visible=True)
95
+ mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
96
+ seed = gr.Number(label="Seed", value=8)
97
+ category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
98
+ g_btn = gr.Button("Edit Image")
99
+
100
+ with gr.Row():
101
+ with gr.Column():
102
+ input_image = gr.Image(label="Input Image", type="pil")
103
+ with gr.Column():
104
+ restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
105
+ with gr.Column():
106
+ origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
107
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
108
+ generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
109
+
110
+ g_btn.click(
111
+ fn=segment_image,
112
+ inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
113
+ outputs=[origin_area_image, croper],
114
+ ).success(
115
+ fn=image_to_image,
116
+ inputs=[origin_area_image, edit_prompt,seed, num_steps, guidance_scale, generate_size],
117
+ outputs=[generated_image, generated_cost],
118
+ ).success(
119
+ fn=restore_result,
120
+ inputs=[croper, category, generated_image],
121
+ outputs=[restored_image],
122
+ )
123
+
124
+ return demo