Linoy Tsaban commited on
Commit
76afba1
·
1 Parent(s): c71b83b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -55
app.py CHANGED
@@ -34,61 +34,76 @@ def caption_image(input_image):
34
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
35
  return generated_caption, generated_caption
36
 
37
- def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
38
-
39
- latnets = wts.value[-1].expand(1, -1, -1, -1)
40
- img = pipe(prompt=prompt_tar,
41
- init_latents=latnets,
42
- guidance_scale = cfg_scale_tar,
43
- # num_images_per_prompt=1,
44
- # num_inference_steps=steps,
45
- # use_ddpm=True,
46
- # wts=wts.value,
47
- zs=zs.value).images[0]
 
48
  return img
49
 
50
- def reconstruct(tar_prompt,
51
- image_caption,
52
- tar_cfg_scale,
53
- skip,
54
- wts, zs,
55
- do_reconstruction,
56
- reconstruction,
57
- reconstruct_button
58
- ):
59
 
 
 
 
 
 
 
 
 
 
 
 
60
  if reconstruct_button == "Hide Reconstruction":
61
- return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=False), do_reconstruction, "Show Reconstruction"
 
 
 
 
 
 
62
 
63
  else:
64
- if do_reconstruction:
65
- if image_caption.lower() == tar_prompt.lower(): # if image caption was not changed, run actual reconstruction
66
- tar_prompt = ""
67
- latnets = wts.value[-1].expand(1, -1, -1, -1)
68
- reconstruction_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
69
- reconstruction = gr.State(value=reconstruction_img)
70
- do_reconstruction = False
71
- return reconstruction.value, reconstruction, ddpm_edited_image.update(visible=True), do_reconstruction, "Hide Reconstruction"
72
-
73
-
 
 
 
 
 
 
 
74
 
75
 
76
  def load_and_invert(
77
- input_image,
78
- do_inversion,
79
- seed, randomize_seed,
80
- wts, zs,
81
- src_prompt ="",
82
- # tar_prompt="",
83
- steps=30,
84
- src_cfg_scale = 3.5,
85
- skip=15,
86
- tar_cfg_scale=15,
87
- progress=gr.Progress(track_tqdm=True)
88
-
 
89
  ):
90
-
91
-
92
  # x0 = load_512(input_image, device=device).to(torch.float16)
93
 
94
  if do_inversion or randomize_seed:
@@ -96,16 +111,14 @@ def load_and_invert(
96
  seed = randomize_seed_fn()
97
  seed_everything(seed)
98
  # invert and retrieve noise maps and latent
99
- zs_tensor, wts_tensor = pipe.invert(
100
- image_path = input_image,
101
- source_prompt =src_prompt,
102
- source_guidance_scale= src_cfg_scale,
103
- num_inversion_steps = steps,
104
- skip = skip,
105
- eta = 1.0,
106
- )
107
- wts = gr.State(value=wts_tensor)
108
- zs = gr.State(value=zs_tensor)
109
  do_inversion = False
110
 
111
  return wts, zs, do_inversion, inversion_progress.update(visible=False)
@@ -171,6 +184,8 @@ def edit(input_image,
171
  edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
172
  edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
173
  edit_threshold=[threshold_1, threshold_2, threshold_3],
 
 
174
  eta=1,
175
  use_cross_attn_mask=use_cross_attn_mask,
176
  use_intersect_mask=use_intersect_mask
 
34
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
35
  return generated_caption, generated_caption
36
 
37
+ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
38
+ latents = wts[-1].expand(1, -1, -1, -1)
39
+ img = pipe(
40
+ prompt=prompt_tar,
41
+ init_latents=latents,
42
+ guidance_scale=cfg_scale_tar,
43
+ # num_images_per_prompt=1,
44
+ # num_inference_steps=steps,
45
+ # use_ddpm=True,
46
+ # wts=wts.value,
47
+ zs=zs,
48
+ ).images[0]
49
  return img
50
 
 
 
 
 
 
 
 
 
 
51
 
52
+ def reconstruct(
53
+ tar_prompt,
54
+ image_caption,
55
+ tar_cfg_scale,
56
+ skip,
57
+ wts,
58
+ zs,
59
+ do_reconstruction,
60
+ reconstruction,
61
+ reconstruct_button,
62
+ ):
63
  if reconstruct_button == "Hide Reconstruction":
64
+ return (
65
+ reconstruction,
66
+ reconstruction,
67
+ ddpm_edited_image.update(visible=False),
68
+ do_reconstruction,
69
+ "Show Reconstruction",
70
+ )
71
 
72
  else:
73
+ if do_reconstruction:
74
+ if (
75
+ image_caption.lower() == tar_prompt.lower()
76
+ ): # if image caption was not changed, run actual reconstruction
77
+ tar_prompt = ""
78
+ latents = wts[-1].expand(1, -1, -1, -1)
79
+ reconstruction = sample(
80
+ zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
81
+ )
82
+ do_reconstruction = False
83
+ return (
84
+ reconstruction,
85
+ reconstruction,
86
+ ddpm_edited_image.update(visible=True),
87
+ do_reconstruction,
88
+ "Hide Reconstruction",
89
+ )
90
 
91
 
92
  def load_and_invert(
93
+ input_image,
94
+ do_inversion,
95
+ seed,
96
+ randomize_seed,
97
+ wts,
98
+ zs,
99
+ src_prompt="",
100
+ # tar_prompt="",
101
+ steps=30,
102
+ src_cfg_scale=3.5,
103
+ skip=15,
104
+ tar_cfg_scale=15,
105
+ progress=gr.Progress(track_tqdm=True),
106
  ):
 
 
107
  # x0 = load_512(input_image, device=device).to(torch.float16)
108
 
109
  if do_inversion or randomize_seed:
 
111
  seed = randomize_seed_fn()
112
  seed_everything(seed)
113
  # invert and retrieve noise maps and latent
114
+ zs, wts = pipe.invert(
115
+ image_path=input_image,
116
+ source_prompt=src_prompt,
117
+ source_guidance_scale=src_cfg_scale,
118
+ num_inversion_steps=steps,
119
+ skip=skip,
120
+ eta=1.0,
121
+ )
 
 
122
  do_inversion = False
123
 
124
  return wts, zs, do_inversion, inversion_progress.update(visible=False)
 
184
  edit_warmup_steps=[warmup_1, warmup_2, warmup_3,],
185
  edit_guidance_scale=[guidnace_scale_1,guidnace_scale_2,guidnace_scale_3],
186
  edit_threshold=[threshold_1, threshold_2, threshold_3],
187
+ edit_momentum_scale=0,
188
+ edit_mom_beta=0.6,
189
  eta=1,
190
  use_cross_attn_mask=use_cross_attn_mask,
191
  use_intersect_mask=use_intersect_mask