1inkusFace commited on
Commit
5d8910b
·
verified ·
1 Parent(s): b9c3563

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -91,12 +91,13 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
91
  #)
92
  pipe.text_encoder.to("cuda")
93
  pipe.text_encoder_2.to("cuda")
94
- #with torch.no_grad():
95
- prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
96
- prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
97
- )
98
  pipe.text_encoder.to("cpu")
99
  pipe.text_encoder_2.to("cpu")
 
100
  generator = torch.Generator(device='cuda').manual_seed(seed)
101
  transformer_dtype = pipe.transformer.dtype
102
  prompt_embeds = prompt_embeds.to(transformer_dtype)
@@ -133,6 +134,7 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
133
  )
134
  image_latents = image_latents.to(pipe.transformer.dtype)
135
  pipe.vae.to("cpu")
 
136
  guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
137
  else:
138
  state_file = f"rv_L_{segment-1}_{seed}.pt"
 
91
  #)
92
  pipe.text_encoder.to("cuda")
93
  pipe.text_encoder_2.to("cuda")
94
+ with torch.no_grad():
95
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_attention_mask, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
96
+ prompt=prompt, do_classifier_free_guidance=True, negative_prompt=negative_prompt, device=device
97
+ )
98
  pipe.text_encoder.to("cpu")
99
  pipe.text_encoder_2.to("cpu")
100
+ torch.cuda.empty_cache()
101
  generator = torch.Generator(device='cuda').manual_seed(seed)
102
  transformer_dtype = pipe.transformer.dtype
103
  prompt_embeds = prompt_embeds.to(transformer_dtype)
 
134
  )
135
  image_latents = image_latents.to(pipe.transformer.dtype)
136
  pipe.vae.to("cpu")
137
+ torch.cuda.empty_cache()
138
  guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
139
  else:
140
  state_file = f"rv_L_{segment-1}_{seed}.pt"