1inkusFace commited on
Commit
baec09d
·
verified ·
1 Parent(s): 138cc18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -151,6 +151,10 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
151
  size = state["width"]
152
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
153
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
 
 
 
 
154
  if segment==9:
155
  pipe.trasformer.to('cpu')
156
  torch.cuda.empty_cache()
@@ -167,10 +171,6 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
167
  return video_out_file, seed
168
  else:
169
  segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
170
- prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
171
- pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
172
- prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
173
- image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
174
  guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
175
  for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
176
  latents = latents.to(transformer_dtype)
 
151
  size = state["width"]
152
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
153
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
154
+ prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
155
+ pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
156
+ prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
157
+ image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
158
  if segment==9:
159
  pipe.trasformer.to('cpu')
160
  torch.cuda.empty_cache()
 
171
  return video_out_file, seed
172
  else:
173
  segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
 
 
 
 
174
  guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
175
  for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
176
  latents = latents.to(transformer_dtype)