nowsyn committed on
Commit
ef7b543
1 Parent(s): 42a50a5

update app.py

Files changed (1)
  1. app.py +266 -4
app.py CHANGED
@@ -1,7 +1,269 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import os
+ import sys
+ sys.path.append(os.getcwd())
+ sys.path.append(os.path.join(os.getcwd(), "annotator/entityseg"))
+ import cv2
+ import spaces
+ import einops
+ import torch
  import gradio as gr
+ import numpy as np
+ from pytorch_lightning import seed_everything
+ from PIL import Image

+ from annotator.util import resize_image, HWC3
+ from annotator.canny import CannyDetector
+ from annotator.midas import MidasDetector
+ from annotator.entityseg import EntitysegDetector
+ from annotator.openpose import OpenposeDetector
+ from annotator.content import ContentDetector
+ from annotator.cielab import CIELabDetector

+ from models.util import create_model, load_state_dict
+ from models.ddim_hacked import DDIMSampler
+
+ '''
+ define conditions
+ '''
+ max_conditions = 8
+ condition_types = ["edge", "depth", "seg", "pose", "content", "color"]
+
+ apply_canny = CannyDetector()
+ apply_midas = MidasDetector()
+ apply_seg = EntitysegDetector()
+ apply_openpose = OpenposeDetector()
+ apply_content = ContentDetector()
+ apply_color = CIELabDetector()
+
+ processors = {
+     "edge": apply_canny,
+     "depth": apply_midas,
+     "seg": apply_seg,
+     "pose": apply_openpose,
+     "content": apply_content,
+     "color": apply_color,
+ }
+
+ descriptors = {
+     "edge": "canny",
+     "depth": "depth",
+     "seg": "segmentation",
+     "pose": "openpose",
+ }
+
+
+ @torch.no_grad()
+ def get_unconditional_global(c_global):
+     if isinstance(c_global, dict):
+         return {k: torch.zeros_like(v) for k, v in c_global.items()}
+     elif isinstance(c_global, list):
+         return [torch.zeros_like(c) for c in c_global]
+     else:
+         return torch.zeros_like(c_global)
+
+
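+ # on ZeroGPU Spaces, @spaces.GPU attaches a GPU for the duration of each call to process()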
+ @spaces.GPU
+ def process(prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps,
+             strength, scale, seed, eta, global_strength, color_strength, local_strength, *args):
+
+     seed_everything(seed)
+
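+     # the dynamic condition rows are appended to the Gradio inputs as (image, type) pairs,
+     # so *args alternates between uploaded condition images and their dropdown types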
+     conds_and_types = args
+     conds = conds_and_types[0::2]
+     types = conds_and_types[1::2]
+     conds = [c for c in conds if c is not None]
+     types = [t for t in types if t is not None]
+     assert len(conds) == len(types)
+
+     detected_maps = []
+     other_maps = []
+     tasks = []
+
+     # initialize global control
+     global_conditions = dict(clipembedding=np.zeros((1, 768), dtype=np.float32), color=np.zeros((1, 180), dtype=np.float32))
+     global_control = {}
+     for key in global_conditions.keys():
+         global_cond = torch.from_numpy(global_conditions[key]).unsqueeze(0).repeat(num_samples, 1, 1)
+         global_cond = global_cond.cuda().to(memory_format=torch.contiguous_format).float()
+         global_control[key] = global_cond
+
+     # initialize local control
+     anchor_image = HWC3(np.zeros((image_resolution, image_resolution, 3)).astype(np.uint8))
+     oH, oW = anchor_image.shape[:2]
+     H, W, C = resize_image(anchor_image, image_resolution).shape
+     anchor_tensor = ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(anchor_image))
+     local_control = torch.tensor(anchor_tensor).cuda().to(memory_format=torch.contiguous_format).half()
+
+     task_prompt = ''
+
+     with torch.no_grad():
+
+         # set up local control
+         for cond, typ in zip(conds, types):
+             if typ in ['edge', 'depth', 'seg', 'pose']:
+                 oH, oW = cond.shape[:2]
+                 cond_image = HWC3(cv2.resize(cond, (W, H)))
+                 cond_detected_map = processors[typ](cond_image)
+                 cond_detected_map = HWC3(cond_detected_map)
+                 detected_maps.append(cond_detected_map)
+                 tasks.append(descriptors[typ])
+             elif typ in ['content']:
+                 other_maps.append(cond)
+                 content_image = cv2.cvtColor(cond, cv2.COLOR_RGB2BGR)
+                 content_emb = apply_content(content_image)
+                 global_conditions['clipembedding'] = content_emb
+             elif typ in ['color']:
+                 color_hist = apply_color(cond)
+                 global_conditions['color'] = color_hist
+                 color_palette = apply_color.hist_to_palette(color_hist)  # (50, 189, 3)
+                 color_palette = cv2.resize(color_palette, (W, H), interpolation=cv2.INTER_NEAREST)
+                 other_maps.append(color_palette)
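+         # stack all detected spatial maps into a single local-control tensor and name them in the task prompt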
+         if len(detected_maps) > 0:
+             local_control = torch.cat([ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(img)).cuda().unsqueeze(0) for img in detected_maps], dim=1)
+             task_prompt = ' conditioned on ' + ' and '.join(tasks)
+         local_control = local_control.repeat(num_samples, 1, 1, 1)
+
+         # set up global control
+         for key in global_conditions.keys():
+             global_cond = torch.from_numpy(global_conditions[key]).unsqueeze(0).repeat(num_samples, 1, 1)
+             global_cond = global_cond.cuda().to(memory_format=torch.contiguous_format).float()
+             global_control[key] = global_cond
+
+         # set up prompt
+         input_prompt = (prompt + ' ' + task_prompt).strip()
+
+         # set up cfg
+         uc_local_control = local_control
+         uc_global_control = get_unconditional_global(global_control)
+         cond = {
+             "local_control": [local_control],
+             "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
+             "global_control": [global_control],
+             "text": [[input_prompt] * num_samples],
+         }
+         un_cond = {
+             "local_control": [uc_local_control],
+             "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)],
+             "global_control": [uc_global_control],
+             "text": [[input_prompt] * num_samples],
+         }
+         shape = (4, H // 8, W // 8)
+
+         model.control_scales = [strength] * 13
+         samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
+                                          shape, cond, verbose=False, eta=eta,
+                                          unconditional_guidance_scale=scale,
+                                          unconditional_conditioning=un_cond,
+                                          global_strength=global_strength,
+                                          color_strength=color_strength,
+                                          local_strength=local_strength)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+         results = [x_samples[i] for i in range(num_samples)]
+
+     results = [cv2.resize(res, (oW, oH)) for res in results]
+     detected_maps = [cv2.resize(maps, (oW, oH)) for maps in detected_maps]
+     return [results, detected_maps + other_maps]
+
+
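+ # slider callback: show the first k condition rows and hide the rest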
+ def variable_image_outputs(k):
+     if k is None:
+         k = 1
+     k = int(k)
+     imageboxes = []
+     for i in range(max_conditions):
+         if i < k:
+             with gr.Row(visible=True):
+                 img = gr.Image(sources=['upload'], type="numpy", label=f'Condition {i+1}', visible=True, interactive=True, scale=3, height=200)
+                 typ = gr.Dropdown(condition_types, visible=True, interactive=True, label="type", scale=1)
+         else:
+             with gr.Row(visible=False):
+                 img = gr.Image(sources=['upload'], type="numpy", label=f'Condition {i+1}', visible=False, scale=3, height=200)
+                 typ = gr.Dropdown(condition_types, visible=False, interactive=True, label="type", scale=1)
+         imageboxes.append(img)
+         imageboxes.append(typ)
+     return imageboxes
+
+
+ '''
+ define model
+ '''
+ config_file = "configs/anycontrol.yaml"
+ model_file = "ckpts/anycontrol_15.ckpt"
+ model = create_model(config_file).cpu()
+ model.load_state_dict(load_state_dict(model_file, location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+
+ block = gr.Blocks(theme='bethecloud/storj_theme').queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## AnyControl Demo")
+     gr.Markdown("---")
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Blocks():
+                 s = gr.Slider(1, max_conditions, value=1, step=1, label="How many conditions to upload:")
+                 imageboxes = []
+                 for i in range(max_conditions):
+                     if i == 0:
+                         with gr.Row():
+                             img = gr.Image(visible=True, sources=['upload'], type="numpy", label='Condition 1', interactive=True, scale=3, height=200)
+                             typ = gr.Dropdown(condition_types, visible=True, interactive=True, label="type", scale=1)
+                     else:
+                         with gr.Row():
+                             img = gr.Image(visible=False, sources=['upload'], type="numpy", label=f'Condition {i+1}', scale=3, height=200)
+                             typ = gr.Dropdown(condition_types, visible=False, interactive=True, label="type", scale=1)
+                     imageboxes.append(img)
+                     imageboxes.append(typ)
+                 s.change(variable_image_outputs, s, imageboxes)
+         with gr.Column(scale=2):
+             with gr.Row():
+                 prompt = gr.Textbox(label="Prompt")
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Accordion("Advanced options", open=False):
+                         num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
+                         image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                         strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1, step=0.01)
+
+                         local_strength = gr.Slider(label="Local Strength", minimum=0, maximum=2, value=1, step=0.01)
+                         global_strength = gr.Slider(label="Global Strength", minimum=0, maximum=2, value=1, step=0.01)
+                         color_strength = gr.Slider(label="Color Strength", minimum=0, maximum=2, value=1, step=0.01)
+
+                         ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
+                         scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                         seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
+                         eta = gr.Number(label="Eta (DDIM)", value=0.0)
+
+                         a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
+                         n_prompt = gr.Textbox(label="Negative Prompt",
+                                               value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+
+
+             with gr.Row():
+                 run_button = gr.Button(value="Run")
+             with gr.Row():
+                 image_gallery = gr.Gallery(label='Generation', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto', interactive=False)
+             with gr.Row():
+                 cond_gallery = gr.Gallery(label='Condition', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto', interactive=False)
+
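+     # keep this list in the same order as the parameters of process()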
+     inputs = [prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps,
+               strength, scale, seed, eta, global_strength, color_strength, local_strength] + imageboxes
+     run_button.click(fn=process, inputs=inputs, outputs=[image_gallery, cond_gallery])
+
+
+ # uncomment this block in case you need it
+ # os.environ['http_proxy'] = ''
+ # os.environ['https_proxy'] = ''
+ # os.environ['no_proxy'] = 'localhost,127.0.0.0/8,127.0.1.1'
+ # os.environ['HTTP_PROXY'] = ''
+ # os.environ['HTTPS_PROXY'] = ''
+ # os.environ['NO_PROXY'] = 'localhost,127.0.0.0/8,127.0.1.1'
+ # os.environ['TMPDIR'] = './tmpfiles'
+
+
+ block.launch(server_name='0.0.0.0', allowed_paths=["."], share=False)