svjack committed
Commit 738df7c · verified · 1 Parent(s): 54ac3f5

Create swtich_app_multi_output.py

Files changed (1)
  1. swtich_app_multi_output.py +444 -0
swtich_app_multi_output.py ADDED
@@ -0,0 +1,444 @@
+ import random
+ import torch
+ import cv2
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from diffusers.utils import load_image
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models.controlnet import ControlNetModel
+ from diffusers import AutoencoderKL
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers import EulerDiscreteScheduler
+ from PIL import Image
+ from annotator.midas import MidasDetector
+ from annotator.dwpose import DWposeDetector
+ from annotator.util import resize_image, HWC3
+
+ device = "cuda"
+
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
+ ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
+ ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")
+ ckpt_dir_ipa = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
+ ckpt_dir_pose = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Pose")
+ '''
+ ckpt_dir = "Kolors"
+ ckpt_dir_depth = "Kolors-ControlNet-Depth"
+ ckpt_dir_canny = "Kolors-ControlNet-Canny"
+ ckpt_dir_ipa = "Kolors-IP-Adapter-Plus"
+ ckpt_dir_pose = "Kolors-ControlNet-Pose"
+ '''
+
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_ipa}/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
+ model_midas = MidasDetector()
+ model_dwpose = DWposeDetector()
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 512
+
+ def process_canny_condition(image, canny_thresholds=[100, 200]):
+     np_image = image.copy()
+     np_image = cv2.Canny(np_image, canny_thresholds[0], canny_thresholds[1])
+     np_image = np_image[:, :, None]
+     np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+     np_image = HWC3(np_image)
+     return Image.fromarray(np_image)
+
+ def process_depth_condition_midas(img, res=1024):
+     h, w, _ = img.shape
+     img = resize_image(HWC3(img), res)
+     result = HWC3(model_midas(img))
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
+ def process_dwpose_condition(image, res=1024):
+     h, w, _ = image.shape
+     img = resize_image(HWC3(image), res)
+     out_res, out_img = model_dwpose(img)
+     result = HWC3(out_img)
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
+ def infer_canny(prompt,
+                 image=None,
+                 ipa_img=None,
+                 negative_prompt="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                 seed=66,
+                 randomize_seed=False,
+                 guidance_scale=5.0,
+                 num_inference_steps=50,
+                 controlnet_conditioning_scale=0.5,
+                 control_guidance_end=0.9,
+                 strength=1.0,
+                 ip_scale=0.5,
+                 num_images=1):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     #generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_canny.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_canny_condition(np.array(init_image))
+     images = []
+     for i in range(num_images):
+         generator = torch.Generator().manual_seed(seed + i)  # offset the seed per image so each output differs
+         image = pipe(
+             prompt=prompt,
+             image=init_image,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             control_guidance_end=control_guidance_end,
+             ip_adapter_image=[ipa_img],
+             strength=strength,
+             control_image=condi_img,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             generator=generator,
+         ).images[0]
+         images.append(image)
+     return [condi_img] + images, seed  # condition map first, then the generated images
+
+ def infer_depth(prompt,
+                 image=None,
+                 ipa_img=None,
+                 negative_prompt="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                 seed=66,
+                 randomize_seed=False,
+                 guidance_scale=5.0,
+                 num_inference_steps=50,
+                 controlnet_conditioning_scale=0.5,
+                 control_guidance_end=0.9,
+                 strength=1.0,
+                 ip_scale=0.5,
+                 num_images=1):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     #generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_depth.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMAGE_SIZE)
+     images = []
+     for i in range(num_images):
+         generator = torch.Generator().manual_seed(seed + i)
+         image = pipe(
+             prompt=prompt,
+             image=init_image,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             control_guidance_end=control_guidance_end,
+             ip_adapter_image=[ipa_img],
+             strength=strength,
+             control_image=condi_img,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             generator=generator,
+         ).images[0]
+         images.append(image)
+     return [condi_img] + images, seed
+
+ def infer_pose(prompt,
+                image=None,
+                ipa_img=None,
+                negative_prompt="nsfw，脸部阴影，低分辨率，jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                seed=66,
+                randomize_seed=False,
+                guidance_scale=5.0,
+                num_inference_steps=50,
+                controlnet_conditioning_scale=0.5,
+                control_guidance_end=0.9,
+                strength=1.0,
+                ip_scale=0.5,
+                num_images=1):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     #generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_pose.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_dwpose_condition(np.array(init_image), MAX_IMAGE_SIZE)
+     images = []
+     for i in range(num_images):
+         generator = torch.Generator().manual_seed(seed + i)
+         image = pipe(
+             prompt=prompt,
+             image=init_image,
+             controlnet_conditioning_scale=controlnet_conditioning_scale,
+             control_guidance_end=control_guidance_end,
+             ip_adapter_image=[ipa_img],
+             strength=strength,
+             control_image=condi_img,
+             negative_prompt=negative_prompt,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=1,
+             generator=generator,
+         ).images[0]
+         images.append(image)
+     return [condi_img] + images, seed
+
+ canny_examples = [
+     ["一个红色头发的女孩，唯美风景，清新明亮，斑驳的光影，最好的质量，超细节，8K画质",
+      "image/woman_2.png", "image/2.png", 3],
+ ]
+
+ depth_examples = [
+     ["一个漂亮的女孩，最好的质量，超细节，8K画质",
+      "image/1.png", "image/woman_1.png", 3],
+ ]
+
+ pose_examples = [
+     ["一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩，超高分辨率，最佳品质，8k画质",
+      "image/woman_3.png", "image/woman_4.png", 3],
+ ]
+
+ css = """
+ #col-left {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ #col-right {
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ #button {
+     color: blue;
+ }
+ """
+
+ def load_description(fp):
+     with open(fp, 'r', encoding='utf-8') as f:
+         content = f.read()
+     return content
+
+ def clear_resources():  # drop whichever ControlNet pipeline is currently loaded to free GPU memory
+     global pipe_canny, pipe_depth, pipe_pose
+     if 'pipe_canny' in globals():
+         del pipe_canny
+     if 'pipe_depth' in globals():
+         del pipe_depth
+     if 'pipe_pose' in globals():
+         del pipe_pose
+     torch.cuda.empty_cache()
+
+ def load_canny_pipeline():
+     global pipe_canny
+     controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)
+     pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_canny,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_canny.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ def load_depth_pipeline():
+     global pipe_depth
+     controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
+     pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_depth,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_depth.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ def load_pose_pipeline():
+     global pipe_pose
+     controlnet_pose = ControlNetModel.from_pretrained(f"{ckpt_dir_pose}", revision=None).half().to(device)
+     pipe_pose = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_pose,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_pose.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ def switch_to_canny():  # unload the other pipelines, then build the Canny one
+     clear_resources()
+     load_canny_pipeline()
+     return gr.update(visible=True)
+
+ def switch_to_depth():
+     clear_resources()
+     load_depth_pipeline()
+     return gr.update(visible=True)
+
+ def switch_to_pose():
+     clear_resources()
+     load_pose_pipeline()
+     return gr.update(visible=True)
+
+ with gr.Blocks(css=css) as Kolors:
+     gr.HTML(load_description("assets/title.md"))
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your prompt",
+                     lines=2
+                 )
+             with gr.Row():
+                 image = gr.Image(label="Image", type="pil")
+                 ipa_image = gr.Image(label="IP-Adapter-Image", type="pil")
+             with gr.Row():
+                 num_images = gr.Slider(
+                     label="Number of Images",
+                     minimum=1,
+                     maximum=10,
+                     step=1,
+                     value=1,
+                 )
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt",
+                     placeholder="Enter a negative prompt",
+                     visible=True,
+                     value="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         minimum=0.0,
+                         maximum=10.0,
+                         step=0.1,
+                         value=5.0,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=10,
+                         maximum=50,
+                         step=1,
+                         value=30,
+                     )
+                 with gr.Row():
+                     controlnet_conditioning_scale = gr.Slider(
+                         label="Controlnet Conditioning Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+                     control_guidance_end = gr.Slider(
+                         label="Control Guidance End",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.9,
+                     )
+                 with gr.Row():
+                     strength = gr.Slider(
+                         label="Strength",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=1.0,
+                     )
+                     ip_scale = gr.Slider(
+                         label="IP_Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+             with gr.Row():
+                 canny_button = gr.Button("Canny", elem_id="button")
+                 depth_button = gr.Button("Depth", elem_id="button")
+                 pose_button = gr.Button("Pose", elem_id="button")
+
+         with gr.Column(elem_id="col-right"):
+             result = gr.Gallery(label="Result", show_label=False, columns=3)
+             seed_used = gr.Number(label="Seed Used")
+
+     with gr.Row():
+         gr.Examples(
+             fn=infer_canny,
+             examples=canny_examples,
+             inputs=[prompt, image, ipa_image, num_images],
+             outputs=[result, seed_used],
+             label="Canny"
+         )
+     with gr.Row():
+         gr.Examples(
+             fn=infer_depth,
+             examples=depth_examples,
+             inputs=[prompt, image, ipa_image, num_images],
+             outputs=[result, seed_used],
+             label="Depth"
+         )
+     with gr.Row():
+         gr.Examples(
+             fn=infer_pose,
+             examples=pose_examples,
+             inputs=[prompt, image, ipa_image, num_images],
+             outputs=[result, seed_used],
+             label="Pose"
+         )
+
+     canny_button.click(  # swap in the Canny pipeline, then run inference
+         fn=switch_to_canny,
+         outputs=[canny_button]
+     ).then(
+         fn=infer_canny,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale, num_images],
+         outputs=[result, seed_used]
+     )
+
+     depth_button.click(
+         fn=switch_to_depth,
+         outputs=[depth_button]
+     ).then(
+         fn=infer_depth,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale, num_images],
+         outputs=[result, seed_used]
+     )
+
+     pose_button.click(
+         fn=switch_to_pose,
+         outputs=[pose_button]
+     ).then(
+         fn=infer_pose,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale, num_images],
+         outputs=[result, seed_used]
+     )
+
+ Kolors.queue().launch(debug=True, share=True)
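
Usage sketch (not part of the committed file): a minimal way to drive the switch-then-infer flow above without the Gradio UI, assuming the definitions from swtich_app_multi_output.py are available in the current session (for example, with the final Kolors.queue().launch(...) line commented out before importing) and that the example images from the repo exist locally.

    from diffusers.utils import load_image

    switch_to_canny()  # frees any previously loaded ControlNet pipeline, then builds pipe_canny
    outputs, used_seed = infer_canny(
        prompt="一个红色头发的女孩，唯美风景，清新明亮，斑驳的光影，最好的质量，超细节，8K画质",
        image=load_image("image/woman_2.png"),   # init / control source image
        ipa_img=load_image("image/2.png"),       # IP-Adapter reference image
        seed=66,
        num_images=2,                            # variations; seeds are offset per image
    )
    # outputs[0] is the Canny condition map, outputs[1:] are the generated images
    for i, img in enumerate(outputs):
        img.save(f"canny_result_{i}.png")
    print("seed used:", used_seed)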