adaface-neurips committed on
Commit 3558684
1 Parent(s): d08dbd0

update app.py, add guidance scale slider

Files changed (2)
  1. app.py +15 -5
  2. lib/pipline_ConsistentID.py +2 -3
app.py CHANGED
@@ -22,7 +22,7 @@ device = f"cuda:{args.gpu}"
 pipe = ConsistentIDPipeline.from_pretrained(
     args.base_model_path,
     torch_dtype=torch.float16,
-).to(device)
+)
 
 ### Load consistentID_model checkpoint
 pipe.load_ConsistentID_model(
@@ -30,11 +30,12 @@ pipe.load_ConsistentID_model(
     bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
 )
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe = pipe.to(device, torch.float16)
 
 @spaces.GPU
 def process(selected_template_images, custom_image, prompt,
             negative_prompt, prompt_selected, model_selected_tab,
-            prompt_selected_tab, width, height, merge_steps, seed_set):
+            prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set):
 
     # The gradio UI only supports one image at a time.
     if model_selected_tab==0:
@@ -80,6 +81,7 @@ def process(selected_template_images, custom_image, prompt,
         negative_prompt=negative_prompt,
         num_images_per_prompt=1,
         num_inference_steps=num_steps,
+        guidance_scale=guidance_scale,
         start_merge_step=merge_steps,
         generator=generator,
     ).images[0]
@@ -135,9 +137,17 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
     prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
     for i, tab in enumerate(prompt_selected_tabs):
         tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
-
+
+    guidance_scale = gr.Slider(
+        label="Guidance scale",
+        minimum=1.0,
+        maximum=10.0,
+        step=1.0,
+        value=5.0,
+    )
+
     width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
-    height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
+    height = gr.Slider(label="image height",minimum=256,maximum=768,value=512,step=8)
     width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
     height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
     merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
@@ -153,6 +163,6 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
     - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
     ''')
     btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected,
-                                  model_selected_tab, prompt_selected_tab, width, height, merge_steps, seed_set], outputs=out)
+                                  model_selected_tab, prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set], outputs=out)
 
 demo.launch(server_name='0.0.0.0', ssl_verify=False)
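
The wiring above follows the standard Gradio pattern: the new slider is listed in btn.click's inputs, so its current value arrives in process() as a plain float and is forwarded to the pipeline call as guidance_scale (the classifier-free guidance weight). The commit also moves .to(device) from the from_pretrained() call to after load_ConsistentID_model(), presumably so that the extra modules loaded from the checkpoint are moved to the GPU in fp16 along with the base pipeline. Below is a minimal, self-contained sketch of the same slider-to-pipeline pattern; the run() stub and its echoed string are illustrative stand-ins, not code from this commit.

import gradio as gr

def run(prompt, guidance_scale):
    # Stand-in for the real process() above: the actual app forwards
    # guidance_scale to the ConsistentIDPipeline call, roughly
    #   image = pipe(prompt=prompt, guidance_scale=guidance_scale, ...).images[0]
    return f"prompt={prompt!r}, guidance_scale={guidance_scale}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=1.0, value=5.0)
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # The slider is listed in `inputs`, so its current value is passed
    # to run() as a float on every click.
    btn.click(fn=run, inputs=[prompt, guidance_scale], outputs=out)

demo.launch()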
lib/pipline_ConsistentID.py CHANGED
@@ -412,9 +412,6 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
 
         # 6. Get the update text embedding
         # parsed_image_parts2: the facial areas of the input image
-        # text_local_id_embeds: [1, 77, 768]
-        # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
-        # to text_global_id_embeds on the rest 73 tokens.
         # extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
         # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
         # parsed_image_parts2: [1, 5, 3, 224, 224]
@@ -424,6 +421,8 @@ class ConsistentIDPipeline(StableDiffusionPipeline):
                                           calc_uncond=calc_uncond)
 
         # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
+        # text_local_id_embeds: [1, 77, 768], only differs with text_embeds on 4 ID embeddings, and is identical
+        # to text_embeds on the rest 73 tokens.
         text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
         text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
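
The relocated comment pins down the tensor shapes: text_local_id_embeds keeps the [1, 77, 768] shape of text_embeds and differs only at the 4 ID-token positions, while concatenating the global ID embeddings along the token dimension yields [1, 81, 768]. A quick shape check with random stand-in tensors, assuming a 4-wide global_id_embeds (inferred from 81 - 77; the real values come from the ID encoders):

import torch

text_embeds          = torch.randn(1, 77, 768)  # CLIP text embeddings
text_local_id_embeds = torch.randn(1, 77, 768)  # same shape; differs only at the 4 ID token positions
global_id_embeds     = torch.randn(1, 4, 768)   # 4 global ID embeddings (width inferred from 81 - 77)

# Concatenate along the token dimension (dim=1), as in the pipeline code above.
text_global_id_embeds       = torch.cat([text_embeds, global_id_embeds], dim=1)
text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)

assert text_global_id_embeds.shape == (1, 81, 768)
assert text_local_global_id_embeds.shape == (1, 81, 768)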