svjack committed on
Commit 578f864 · verified · 1 Parent(s): c16a841

Create switch_app.py

Files changed (1):
  switch_app.py +411 -0
switch_app.py ADDED
@@ -0,0 +1,411 @@
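+ # switch_app.py: Gradio demo for Kolors ControlNet (Canny / Depth / Pose) with IP-Adapter-Plus
+ # image prompting. Only one ControlNet pipeline is kept on the GPU at a time; the mode buttons
+ # call switch_to_* handlers that free the current pipeline and load the requested one.
+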
+ import random
+ import torch
+ import cv2
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from diffusers.utils import load_image
+ from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
+ from kolors.models.modeling_chatglm import ChatGLMModel
+ from kolors.models.tokenization_chatglm import ChatGLMTokenizer
+ from kolors.models.controlnet import ControlNetModel
+ from diffusers import AutoencoderKL
+ from kolors.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers import EulerDiscreteScheduler
+ from PIL import Image
+ from annotator.midas import MidasDetector
+ from annotator.dwpose import DWposeDetector
+ from annotator.util import resize_image, HWC3
+
+ device = "cuda"
+ ckpt_dir = snapshot_download(repo_id="Kwai-Kolors/Kolors")
+ ckpt_dir_depth = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Depth")
+ ckpt_dir_canny = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Canny")
+ ckpt_dir_ipa = snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus")
+ ckpt_dir_pose = snapshot_download(repo_id="Kwai-Kolors/Kolors-ControlNet-Pose")
+
+ text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
+ tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
+ vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
+ scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
+ unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
+
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ckpt_dir_ipa}/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16, device=device)
+ ip_img_size = 336
+ clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
+
+ model_midas = MidasDetector()
+ model_dwpose = DWposeDetector()
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
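+ # Condition-image preprocessors: each converts the input image into the control image the
+ # corresponding ControlNet expects (Canny edge map, MiDaS depth map, DWpose skeleton).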
+ def process_canny_condition(image, canny_threods=[100, 200]):
+     np_image = image.copy()
+     np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1])
+     np_image = np_image[:, :, None]
+     np_image = np.concatenate([np_image, np_image, np_image], axis=2)
+     np_image = HWC3(np_image)
+     return Image.fromarray(np_image)
+
+ def process_depth_condition_midas(img, res=1024):
+     h, w, _ = img.shape
+     img = resize_image(HWC3(img), res)
+     result = HWC3(model_midas(img))
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
+ def process_dwpose_condition(image, res=1024):
+     h, w, _ = image.shape
+     img = resize_image(HWC3(image), res)
+     out_res, out_img = model_dwpose(image)
+     result = HWC3(out_img)
+     result = cv2.resize(result, (w, h))
+     return Image.fromarray(result)
+
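+ # Inference entry points for the three modes. Each resizes the input image, builds the
+ # control image, applies the IP-Adapter scale to the currently loaded pipeline
+ # (pipe_canny / pipe_depth / pipe_pose, created by the switch_to_* handlers below),
+ # and returns the control image, the generated image, and the seed that was used.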
+ def infer_canny(prompt,
+                 image=None,
+                 ipa_img=None,
+                 # Default negative prompt (Chinese): "nsfw, facial shadows, low resolution, bad anatomy, bad hands,
+                 # missing fingers, worst quality, low quality, JPEG artifacts, blurry, bad, dark face, neon lights"
+                 negative_prompt="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                 seed=66,
+                 randomize_seed=False,
+                 guidance_scale=5.0,
+                 num_inference_steps=50,
+                 controlnet_conditioning_scale=0.5,
+                 control_guidance_end=0.9,
+                 strength=1.0,
+                 ip_scale=0.5,
+                 ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_canny.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_canny_condition(np.array(init_image))
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         control_guidance_end=control_guidance_end,
+         ip_adapter_image=[ipa_img],
+         strength=strength,
+         control_image=condi_img,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+     return [condi_img, image], seed
+
+ def infer_depth(prompt,
+                 image=None,
+                 ipa_img=None,
+                 negative_prompt="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                 seed=66,
+                 randomize_seed=False,
+                 guidance_scale=5.0,
+                 num_inference_steps=50,
+                 controlnet_conditioning_scale=0.5,
+                 control_guidance_end=0.9,
+                 strength=1.0,
+                 ip_scale=0.5,
+                 ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_depth.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_depth_condition_midas(np.array(init_image), MAX_IMAGE_SIZE)
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         control_guidance_end=control_guidance_end,
+         ip_adapter_image=[ipa_img],
+         strength=strength,
+         control_image=condi_img,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+     return [condi_img, image], seed
+
+ def infer_pose(prompt,
+                image=None,
+                ipa_img=None,
+                # Default negative prompt (Chinese): "nsfw, facial shadows, low resolution, JPEG artifacts, blurry, bad, dark face, neon lights"
+                negative_prompt="nsfw，脸部阴影，低分辨率，jpeg伪影、模糊、糟糕，黑脸，霓虹灯",
+                seed=66,
+                randomize_seed=False,
+                guidance_scale=5.0,
+                num_inference_steps=50,
+                controlnet_conditioning_scale=0.5,
+                control_guidance_end=0.9,
+                strength=1.0,
+                ip_scale=0.5,
+                ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     init_image = resize_image(image, MAX_IMAGE_SIZE)
+     pipe = pipe_pose.to("cuda")
+     pipe.set_ip_adapter_scale([ip_scale])
+     condi_img = process_dwpose_condition(np.array(init_image), MAX_IMAGE_SIZE)
+     image = pipe(
+         prompt=prompt,
+         image=init_image,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         control_guidance_end=control_guidance_end,
+         ip_adapter_image=[ipa_img],
+         strength=strength,
+         control_image=condi_img,
+         negative_prompt=negative_prompt,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=1,
+         generator=generator,
+     ).images[0]
+     return [condi_img, image], seed
+
+ canny_examples = [
+     # "A girl with red hair, beautiful scenery, fresh and bright, dappled light and shadow, best quality, ultra detailed, 8K quality"
+     ["一个红色头发的女孩，唯美风景，清新明亮，斑驳的光影，最好的质量，超细节，8K画质",
+      "image/woman_2.png", "image/2.png"],
+ ]
+
+ depth_examples = [
+     # "A beautiful girl, best quality, ultra detailed, 8K quality"
+     ["一个漂亮的女孩，最好的质量，超细节，8K画质",
+      "image/1.png", "image/woman_1.png"],
+ ]
+
+ pose_examples = [
+     # "A girl in a purple puff-sleeve dress, wearing a crown and white lace gloves, ultra high resolution, best quality, 8k quality"
+     ["一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩，超高分辨率，最佳品质，8k画质",
+      "image/woman_3.png", "image/woman_4.png"],
+ ]
+
+ css = """
+ #col-left {
+     margin: 0 auto;
+     max-width: 600px;
+ }
+ #col-right {
+     margin: 0 auto;
+     max-width: 750px;
+ }
+ #button {
+     color: blue;
+ }
+ """
+
+ def load_description(fp):
+     with open(fp, 'r', encoding='utf-8') as f:
+         content = f.read()
+     return content
+
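+ # Pipeline management: clear_resources() deletes whichever ControlNet pipeline is currently
+ # loaded and empties the CUDA cache, and each load_*_pipeline() rebuilds a
+ # StableDiffusionXLControlNetImg2ImgPipeline around the shared text encoder, VAE, UNet and
+ # scheduler, then attaches the Kolors IP-Adapter-Plus weights.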
+ def clear_resources():
+     global pipe_canny, pipe_depth, pipe_pose
+     if 'pipe_canny' in globals():
+         del pipe_canny
+     if 'pipe_depth' in globals():
+         del pipe_depth
+     if 'pipe_pose' in globals():
+         del pipe_pose
+     torch.cuda.empty_cache()
+
+ def load_canny_pipeline():
+     global pipe_canny
+     controlnet_canny = ControlNetModel.from_pretrained(f"{ckpt_dir_canny}", revision=None).half().to(device)
+     pipe_canny = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_canny,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_canny.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ def load_depth_pipeline():
+     global pipe_depth
+     controlnet_depth = ControlNetModel.from_pretrained(f"{ckpt_dir_depth}", revision=None).half().to(device)
+     pipe_depth = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_depth,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_depth.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
+ def load_pose_pipeline():
+     global pipe_pose
+     controlnet_pose = ControlNetModel.from_pretrained(f"{ckpt_dir_pose}", revision=None).half().to(device)
+     pipe_pose = StableDiffusionXLControlNetImg2ImgPipeline(
+         vae=vae,
+         controlnet=controlnet_pose,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=scheduler,
+         image_encoder=image_encoder,
+         feature_extractor=clip_image_processor,
+         force_zeros_for_empty_prompt=False
+     )
+     pipe_pose.load_ip_adapter(f'{ckpt_dir_ipa}', subfolder="", weight_name=["ip_adapter_plus_general.bin"])
+
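+ # Mode switches bound to the Canny / Depth / Pose buttons: drop the old pipeline,
+ # load the requested one, and keep the clicked button visible.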
+ def switch_to_canny():
+     clear_resources()
+     load_canny_pipeline()
+     return gr.update(visible=True)
+
+ def switch_to_depth():
+     clear_resources()
+     load_depth_pipeline()
+     return gr.update(visible=True)
+
+ def switch_to_pose():
+     clear_resources()
+     load_pose_pipeline()
+     return gr.update(visible=True)
+
+ with gr.Blocks(css=css) as Kolors:
+     gr.HTML(load_description("assets/title.md"))
+     with gr.Row():
+         with gr.Column(elem_id="col-left"):
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     label="Prompt",
+                     placeholder="Enter your prompt",
+                     lines=2
+                 )
+             with gr.Row():
+                 image = gr.Image(label="Image", type="pil")
+                 ipa_image = gr.Image(label="IP-Adapter-Image", type="pil")
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt = gr.Textbox(
+                     label="Negative prompt",
+                     placeholder="Enter a negative prompt",
+                     visible=True,
+                     # Same Chinese negative prompt as the infer_* defaults (see translation above).
+                     value="nsfw，脸部阴影，低分辨率，糟糕的解剖结构、糟糕的手，缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕，黑脸，霓虹灯"
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=0,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+                 with gr.Row():
+                     guidance_scale = gr.Slider(
+                         label="Guidance scale",
+                         minimum=0.0,
+                         maximum=10.0,
+                         step=0.1,
+                         value=5.0,
+                     )
+                     num_inference_steps = gr.Slider(
+                         label="Number of inference steps",
+                         minimum=10,
+                         maximum=50,
+                         step=1,
+                         value=30,
+                     )
+                 with gr.Row():
+                     controlnet_conditioning_scale = gr.Slider(
+                         label="Controlnet Conditioning Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+                     control_guidance_end = gr.Slider(
+                         label="Control Guidance End",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.9,
+                     )
+                 with gr.Row():
+                     strength = gr.Slider(
+                         label="Strength",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=1.0,
+                     )
+                     ip_scale = gr.Slider(
+                         label="IP_Scale",
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         value=0.5,
+                     )
+             with gr.Row():
+                 canny_button = gr.Button("Canny", elem_id="button")
+                 depth_button = gr.Button("Depth", elem_id="button")
+                 pose_button = gr.Button("Pose", elem_id="button")
+
+         with gr.Column(elem_id="col-right"):
+             result = gr.Gallery(label="Result", show_label=False, columns=2)
+             seed_used = gr.Number(label="Seed Used")
+
+     with gr.Row():
+         gr.Examples(
+             fn=infer_canny,
+             examples=canny_examples,
+             inputs=[prompt, image, ipa_image],
+             outputs=[result, seed_used],
+             label="Canny"
+         )
+     with gr.Row():
+         gr.Examples(
+             fn=infer_depth,
+             examples=depth_examples,
+             inputs=[prompt, image, ipa_image],
+             outputs=[result, seed_used],
+             label="Depth"
+         )
+     with gr.Row():
+         gr.Examples(
+             fn=infer_pose,
+             examples=pose_examples,
+             inputs=[prompt, image, ipa_image],
+             outputs=[result, seed_used],
+             label="Pose"
+         )
+
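+     # Each button is wired to two handlers: the matching infer_* function below, and the
+     # switch_to_* helper further down that rebuilds the corresponding ControlNet pipeline.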
+     canny_button.click(
+         fn=infer_canny,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale],
+         outputs=[result, seed_used]
+     )
+
+     depth_button.click(
+         fn=infer_depth,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale],
+         outputs=[result, seed_used]
+     )
+
+     pose_button.click(
+         fn=infer_pose,
+         inputs=[prompt, image, ipa_image, negative_prompt, seed, randomize_seed, guidance_scale, num_inference_steps, controlnet_conditioning_scale, control_guidance_end, strength, ip_scale],
+         outputs=[result, seed_used]
+     )
+
+     canny_button.click(switch_to_canny, outputs=[canny_button])
+     depth_button.click(switch_to_depth, outputs=[depth_button])
+     pose_button.click(switch_to_pose, outputs=[pose_button])
+
+ Kolors.queue().launch(debug=True, share=True)