Linoy Tsaban commited on
Commit
8f7289c
·
1 Parent(s): 998e5bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -21
app.py CHANGED
@@ -132,7 +132,23 @@ def edit(input_image,
132
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
133
  threshold_1, threshold_2, threshold_3,
134
  do_reconstruction,
135
- reconstruction):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
138
  editing_args = dict(
@@ -151,7 +167,7 @@ def edit(input_image,
151
  num_inference_steps=steps,
152
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
153
 
154
- return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction
155
 
156
  else: # if sega concepts were not added, performs regular ddpm sampling
157
 
@@ -159,9 +175,9 @@ def edit(input_image,
159
  pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
160
  reconstruction = gr.State(value=pure_ddpm_img)
161
  do_reconstruction = False
162
- return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
163
 
164
- return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction
165
 
166
 
167
  def randomize_seed_fn(seed, randomize_seed):
@@ -635,21 +651,7 @@ with gr.Blocks(css="style.css") as demo:
635
  #add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
636
  # outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
637
 
638
- run_button.click(fn = update_inversion_progress_visibility, inputs =[input_image,do_inversion], outputs=[inversion_progress],queue=False).then(
639
- fn=load_and_invert,
640
- inputs=[input_image,
641
- do_inversion,
642
- seed, randomize_seed,
643
- wts, zs,
644
- src_prompt,
645
- tar_prompt,
646
- steps,
647
- src_cfg_scale,
648
- skip,
649
- tar_cfg_scale
650
- ],
651
- outputs=[wts, zs, do_inversion, inversion_progress],
652
- ).success(
653
  fn=edit,
654
  inputs=[input_image,
655
  wts, zs,
@@ -661,10 +663,16 @@ with gr.Blocks(css="style.css") as demo:
661
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
662
  warmup_1, warmup_2, warmup_3,
663
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
664
- threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction
 
 
 
 
 
 
665
 
666
  ],
667
- outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction])
668
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
669
 
670
 
 
132
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
133
  threshold_1, threshold_2, threshold_3,
134
  do_reconstruction,
135
+ reconstruction,
136
+
137
+ # for inversion in case it needs to be re computed (and avoid delay):
138
+ do_inversion,
139
+ seed,
140
+ randomize_seed,
141
+ src_prompt,
142
+ src_cfg_scale):
143
+
144
+ if do_inversion or randomize_seed:
145
+ x0 = load_512(input_image, device=device)
146
+ # invert and retrieve noise maps and latent
147
+ zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
148
+ wts = gr.State(value=wts_tensor)
149
+ zs = gr.State(value=zs_tensor)
150
+ do_inversion = False
151
+
152
 
153
  if edit_concept_1 != "" or edit_concept_2 != "" or edit_concept_3 != "":
154
  editing_args = dict(
 
167
  num_inference_steps=steps,
168
  use_ddpm=True, wts=wts.value, zs=zs.value[skip:], **editing_args)
169
 
170
+ return sega_out.images[0], reconstruct_button.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion
171
 
172
  else: # if sega concepts were not added, performs regular ddpm sampling
173
 
 
175
  pure_ddpm_img = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
176
  reconstruction = gr.State(value=pure_ddpm_img)
177
  do_reconstruction = False
178
+ return pure_ddpm_img, reconstruct_button.update(visible=False), do_reconstruction, reconstruction wts, zs, do_inversion
179
 
180
+ return reconstruction.value, reconstruct_button.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion
181
 
182
 
183
  def randomize_seed_fn(seed, randomize_seed):
 
651
  #add_concept_button.click(fn = update_display_concept, inputs=sega_concepts_counter,
652
  # outputs= [row2, row2_advanced, row3, row3_advanced, add_concept_button, sega_concepts_counter], queue = False)
653
 
654
+ run_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  fn=edit,
656
  inputs=[input_image,
657
  wts, zs,
 
663
  guidnace_scale_1,guidnace_scale_2,guidnace_scale_3,
664
  warmup_1, warmup_2, warmup_3,
665
  neg_guidance_1, neg_guidance_2, neg_guidance_3,
666
+ threshold_1, threshold_2, threshold_3, do_reconstruction, reconstruction,
667
+ do_inversion,
668
+ seed,
669
+ randomize_seed,
670
+ src_prompt,
671
+ src_cfg_scale
672
+
673
 
674
  ],
675
+ outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion])
676
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
677
 
678