XuDongZhou committed
Commit 4375fb1 · verified · 1 Parent(s): 1633b7b

Update app.py

Files changed (1)
  1. app.py +354 -106
app.py CHANGED
@@ -1,154 +1,402 @@
- import gradio as gr
- import numpy as np
  import random

- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch

- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use

- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32

- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)

  MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024


- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
      prompt,
      negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
      guidance_scale,
-     num_inference_steps,
      progress=gr.Progress(track_tqdm=True),
  ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)

-     generator = torch.Generator().manual_seed(seed)

-     image = pipe(
          prompt=prompt,
          negative_prompt=negative_prompt,
          guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
          height=height,
-         generator=generator,
-     ).images[0]

-     return image, seed


- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]

  css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
  """
-
  with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")

-         with gr.Row():
-             prompt = gr.Text(
                  label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
              )

-             run_button = gr.Button("Run", scale=0, variant="primary")
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
              )

-             seed = gr.Slider(
-                 label="Seed",
                  minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
              )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
                  )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
                  )
-
-             with gr.Row():
                  guidance_scale = gr.Slider(
                      label="Guidance scale",
-                     minimum=0.0,
                      maximum=10.0,
                      step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
                  )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
                      step=1,
-                     value=2,  # Replace with defaults that work for your model
                  )

-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
      )

- if __name__ == "__main__":
-     demo.launch()
+ import cv2, os
+ import torch
  import random
+ import numpy as np

+ import spaces

+ import PIL
+ from PIL import Image
+ from typing import Tuple

+ import diffusers
+ from diffusers.utils import load_image

+ from diffusers import (
+     AutoencoderKL,
+     UNet2DConditionModel,
+     UniPCMultistepScheduler,
+ )

+ from huggingface_hub import hf_hub_download
+
+ from insightface.app import FaceAnalysis
+
+ # draw_kps is assumed to be exported by the bundled pipeline module, as in the upstream InstantID demo
+ from pipeline_controlnet_xs_sd_xl_instantid import StableDiffusionXLInstantIDXSPipeline, UNetControlNetXSModel, draw_kps
+
+ from utils.controlnet_xs import ControlNetXSAdapter
+ # from controlnet_aux import OpenposeDetector
+
+ import gradio as gr
+
+ import torch.nn.functional as F
+ from torchvision.transforms import Compose
+
+ # global variables
  MAX_SEED = np.iinfo(np.int32).max
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = torch.float16 if "cuda" in device else torch.float32
+
+
+ base_model = 'frankjoshua/realvisxlV40_v40Bakedvae'
+ vae_path = 'madebyollin/sdxl-vae-fp16-fix'
+ ckpt = 'RED-AIGC/InstantID-XS'
+
+ image_proj_path = os.path.join(ckpt, "image_proj.bin")
+ cnxs_path = os.path.join(ckpt, "controlnetxs.bin")
+ cross_attn_path = os.path.join(ckpt, "cross_attn.bin")
+
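The three checkpoint paths above only resolve if `RED-AIGC/InstantID-XS` is cloned next to the app. A minimal sketch of the alternative, pulling the same files through the already-imported `hf_hub_download` (file names taken from the commit; whether a download is needed depends on the Space layout):

    # Fetch the InstantID-XS weights from the Hub instead of relying on a local clone.
    image_proj_path = hf_hub_download(repo_id=ckpt, filename="image_proj.bin")
    cnxs_path = hf_hub_download(repo_id=ckpt, filename="controlnetxs.bin")
    cross_attn_path = hf_hub_download(repo_id=ckpt, filename="cross_attn.bin")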
+ # Load face encoder
+ app = FaceAnalysis(
+     name="antelopev2",
+     root="./",
+     providers=["CPUExecutionProvider"],
+ )
+ app.prepare(ctx_id=0, det_size=(640, 640))
+
+
+ def get_ControlNetXS(base_model, cnxs_path, device, size_ratio=0.125, weight_dtype=torch.float16):
+     # Build a ControlNet-XS adapter on top of the base UNet, then load the released weights into it.
+     unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet").to(device, dtype=weight_dtype)
+     controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=size_ratio, learn_time_embedding=True)
+     state_dict = torch.load(cnxs_path, map_location="cpu", weights_only=True)
+     # Remap checkpoint keys to the adapter's naming scheme.
+     ctrl_state_dict = {}
+     for key, value in state_dict.items():
+         if 'attn2.processor' not in key:
+             if 'ctrl_' in key and 'ctrl_to_base' not in key:
+                 key = key.replace('ctrl_', '')
+             if 'up_blocks' in key:
+                 key = key.replace('up_blocks', 'up_connections')
+             ctrl_state_dict[key] = value
+     controlnet.load_state_dict(ctrl_state_dict, strict=True)
+     controlnet.to(device, dtype=weight_dtype)
+     ControlNetXS = UNetControlNetXSModel.from_unet(unet, controlnet).to(device, dtype=weight_dtype)
+
+     return ControlNetXS
+
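The key remapping in `get_ControlNetXS` strips the `ctrl_` prefix (except for the `ctrl_to_base` connectors), renames `up_blocks` to `up_connections`, and skips the attn2 processor entries. A small self-check with illustrative key names (not taken from the checkpoint):

    def _remap(key: str) -> str:
        # Mirrors the loop above for a single key.
        if 'ctrl_' in key and 'ctrl_to_base' not in key:
            key = key.replace('ctrl_', '')
        if 'up_blocks' in key:
            key = key.replace('up_blocks', 'up_connections')
        return key

    assert _remap("ctrl_up_blocks.1.resnets.0.conv1.weight") == "up_connections.1.resnets.0.conv1.weight"
    assert _remap("ctrl_to_base.0.weight") == "ctrl_to_base.0.weight"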
+ ControlNetXS = get_ControlNetXS(base_model, cnxs_path, device, size_ratio=0.125, weight_dtype=weight_dtype)
+ vae = AutoencoderKL.from_pretrained(vae_path)
+ pipe = StableDiffusionXLInstantIDXSPipeline.from_pretrained(
+     base_model,
+     vae=vae,
+     unet=ControlNetXS,
+     controlnet=None,
+     torch_dtype=weight_dtype,
+ )
+
+ pipe.cuda(device=device, dtype=weight_dtype, use_xformers=True)
+ pipe.load_ip_adapter(image_proj_path, cross_attn_path)
+
+ pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+ pipe.unet.config.ctrl_learn_time_embedding = True
+ pipe = pipe.to(device)
+
+
+ def toggle_lcm_ui(value):
+     if value:
+         return (
+             gr.update(minimum=0, maximum=100, step=1, value=5),
+             gr.update(minimum=0.1, maximum=20.0, step=0.1, value=1.5),
+         )
+     else:
+         return (
+             gr.update(minimum=5, maximum=100, step=1, value=30),
+             gr.update(minimum=0.1, maximum=20.0, step=0.1, value=5),
+         )
+
+
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     return seed
+
+
+ def remove_tips():
+     return gr.update(visible=False)
+
+
+ def get_example():
+     case = [
+         [
+             "./examples/1.jpg",
+             None,
+             "a woman,(looking at the viewer), portrait, daily wear, 8K texture, realistic, symmetrical hyperdetailed texture, masterpiece, enhanced details, (eye highlight:2), perfect composition, natural lighting, best quality, authentic, natural posture",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+         [
+             "./examples/1.jpeg",
+             "./examples/poses/pose1.jpg",
+             "a woman,(looking at the viewer), portrait, daily wear, 8K texture, realistic, symmetrical hyperdetailed texture, masterpiece, enhanced details, (eye highlight:2), perfect composition, natural lighting, best quality, authentic, natural posture",
+             "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, photo, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
+         ],
+     ]
+     return case
+
+
+ def run_for_examples(face_file, pose_file, prompt, negative_prompt):
+     return generate_image(
+         face_file,
+         pose_file,
+         prompt,
+         negative_prompt,
+         20,  # num_steps
+         0.8,  # identitynet_strength_ratio (controlnet_conditioning_scale)
+         0.8,  # adapter_strength_ratio
+         5.0,  # guidance_scale
+         42,  # seed
+     )
+
+
+ def convert_from_cv2_to_image(img: np.ndarray) -> Image:
+     return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+
+ def convert_from_image_to_cv2(img: Image) -> np.ndarray:
+     return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+
+
+ def resize_img(
+     input_image,
+     max_side=1280,
+     min_side=1024,
+     size=None,
+     pad_to_max_side=False,
+     mode=PIL.Image.BILINEAR,
+     base_pixel_number=64,
+ ):
+     w, h = input_image.size
+     if size is not None:
+         w_resize_new, h_resize_new = size
+     else:
+         ratio = min_side / min(h, w)
+         w, h = round(ratio * w), round(ratio * h)
+         ratio = max_side / max(h, w)
+         input_image = input_image.resize([round(ratio * w), round(ratio * h)], mode)
+         w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+         h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+     input_image = input_image.resize([w_resize_new, h_resize_new], mode)

+     if pad_to_max_side:
+         res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+         offset_x = (max_side - w_resize_new) // 2
+         offset_y = (max_side - h_resize_new) // 2
+         res[
+             offset_y : offset_y + h_resize_new, offset_x : offset_x + w_resize_new
+         ] = np.array(input_image)
+         input_image = Image.fromarray(res)
+     return input_image

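`resize_img` normalizes arbitrary uploads to dimensions the SDXL pipeline can handle; a worked example with an illustrative 3000x2000 input and the defaults above:

    # 3000x2000 is first scaled so the short side reaches min_side=1024 (-> 1536x1024),
    # then capped so the long side is at most max_side=1280 (-> 1280x853),
    # and finally snapped down to multiples of base_pixel_number=64 (-> 1280x832).
    example = Image.new("RGB", (3000, 2000))
    print(resize_img(example).size)  # (1280, 832)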
+ @spaces.GPU
+ def generate_image(
+     face_image_path,
+     pose_image_path,
      prompt,
      negative_prompt,
+     num_steps,
+     controlnet_conditioning_scale,
+     adapter_strength_ratio,
      guidance_scale,
+     seed,
      progress=gr.Progress(track_tqdm=True),
  ):
+     if face_image_path is None:
+         raise gr.Error(
+             "Cannot find any input face image! Please upload the face image"
+         )
+
+     if prompt is None:
+         prompt = "a person"
+
+     # apply the style template (disabled: apply_style / style_name are not defined in this app)
+     # prompt, negative_prompt = apply_style(style_name, prompt, negative_prompt)
+
+     face_image = load_image(face_image_path)
+     face_image = resize_img(face_image, max_side=1024)
+     face_image_cv2 = convert_from_image_to_cv2(face_image)
+     height, width, _ = face_image_cv2.shape
+
+     # Extract face features
+     face_info = app.get(face_image_cv2)
+
+     if len(face_info) == 0:
+         raise gr.Error(
+             "Unable to detect a face in the image. Please upload a different photo with a clear face."
+         )
+
+     face_info = sorted(
+         face_info,
+         key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]),
+     )[-1]  # only use the largest face
+
+     face_emb = torch.from_numpy(face_info.normed_embedding)
+     face_kps = draw_kps(convert_from_cv2_to_image(face_image_cv2), face_info["kps"])
+     img_controlnet = face_image
+     if pose_image_path is not None:
+         pose_image = load_image(pose_image_path)
+         pose_image = resize_img(pose_image, max_side=1024)
+         img_controlnet = pose_image
+         pose_image_cv2 = convert_from_image_to_cv2(pose_image)
+
+         face_info = app.get(pose_image_cv2)
+
+         if len(face_info) == 0:
+             raise gr.Error(
+                 "Cannot find any face in the reference image! Please upload another person image"
+             )
+
+         face_info = face_info[-1]
+         face_kps = draw_kps(pose_image, face_info["kps"])
+
+         width, height = face_kps.size
+
+     print("Start inference...")
+     print(f"[Debug] Prompt: {prompt}, \n[Debug] Neg Prompt: {negative_prompt}")
+
+     pipe.set_ip_adapter_scale(adapter_strength_ratio)
+     images = pipe(
          prompt=prompt,
          negative_prompt=negative_prompt,
+         image_embeds=face_emb,
+         image=face_kps,
+         control_mask=control_mask,
+         controlnet_conditioning_scale=controlnet_conditioning_scale,
+         num_inference_steps=num_steps,
          guidance_scale=guidance_scale,
          height=height,
+         width=width,
+         generator=torch.Generator(device=device).manual_seed(seed),
+     ).images
+
+     return images[0], gr.update(visible=True)


  css = """
+ .gradio-container {width: 85% !important}
  """
 
  with gr.Blocks(css=css) as demo:
+     # description
+     gr.Markdown(title)
+     gr.Markdown(description)

+     with gr.Row():
+         with gr.Column():
+             with gr.Row(equal_height=True):
+                 # upload face image
+                 face_file = gr.Image(
+                     label="Upload a photo of your face", type="filepath"
+                 )
+                 # optional: upload a reference pose image
+                 pose_file = gr.Image(
+                     label="Upload a reference pose image (Optional)",
+                     type="filepath",
+                 )
+
+             # prompt
+             prompt = gr.Textbox(
                  label="Prompt",
+                 info="A simple prompt is enough to achieve good face fidelity",
+                 placeholder="A photo of a person",
+                 value="",
              )

+             submit = gr.Button("Submit", variant="primary")
+             enable_LCM = gr.Checkbox(
+                 label="Enable Fast Inference with LCM", value=enable_lcm_arg,
+                 info="LCM speeds up inference; the trade-off is image quality. It works better on close-up portraits than on distant faces",
              )

+             # strength
+             controlnet_conditioning_scale = gr.Slider(
+                 label="IdentityNet strength (for fidelity)",
                  minimum=0,
+                 maximum=1.0,
+                 step=0.1,
+                 value=0.8,
+             )
+             adapter_strength_ratio = gr.Slider(
+                 label="Image adapter strength (for detail)",
+                 minimum=0,
+                 maximum=1.2,
+                 step=0.1,
+                 value=0.8,
              )

+             with gr.Accordion(open=False, label="Advanced Options"):
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     placeholder="low quality",
+                     value="(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed cat, deformed, photo, anthropomorphic cat, monochrome, pet collar, gun, weapon, blue, 3d, drones, drone, buildings in background, green",
                  )
+                 num_steps = gr.Slider(
+                     label="Number of sample steps",
+                     minimum=1,
+                     maximum=100,
+                     step=1,
+                     value=20,
                  )
                  guidance_scale = gr.Slider(
                      label="Guidance scale",
+                     minimum=0.1,
                      maximum=10.0,
                      step=0.1,
+                     value=5.0,
                  )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
                      step=1,
+                     value=42,
                  )

+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+         with gr.Column(scale=1):
+             gallery = gr.Image(label="Generated Images")
+             usage_tips = gr.Markdown(
+                 label="InstantID Usage Tips", value=tips, visible=False
+             )
+
+     submit.click(
+         fn=remove_tips,
+         outputs=usage_tips,
+     ).then(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate_image,
+         inputs=[
+             face_file,
+             pose_file,
+             prompt,
+             negative_prompt,
+             num_steps,
+             controlnet_conditioning_scale,
+             adapter_strength_ratio,
+             guidance_scale,
+             seed,
+         ],
+         outputs=[gallery, usage_tips],
+     )
+
+     enable_LCM.input(
+         fn=toggle_lcm_ui,
+         inputs=[enable_LCM],
+         outputs=[num_steps, guidance_scale],
+         queue=False,
+     )
+
+     gr.Examples(
+         examples=get_example(),
+         inputs=[face_file, pose_file, prompt, negative_prompt],
+         fn=run_for_examples,
+         outputs=[gallery, usage_tips],
+         cache_examples=True,
      )

+     gr.Markdown(article)
+
+ demo.queue(api_open=False)
+ demo.launch()