1inkusFace committed on
Commit
138cc18
·
verified ·
1 Parent(s): dc7a5fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -151,13 +151,6 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
151
  size = state["width"]
152
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
153
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
154
- segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
155
- prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
156
- pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
157
- prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
158
- image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
159
- guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
160
- #pipe.transformer.to("cuda")
161
  if segment==9:
162
  pipe.trasformer.to('cpu')
163
  torch.cuda.empty_cache()
@@ -173,6 +166,12 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
173
  export_to_video(output, video_out_file, fps=24)
174
  return video_out_file, seed
175
  else:
 
 
 
 
 
 
176
  for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
177
  latents = latents.to(transformer_dtype)
178
  latent_model_input = torch.cat([latents] * 2)
 
151
  size = state["width"]
152
  pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
153
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
 
 
 
 
 
 
 
154
  if segment==9:
155
  pipe.trasformer.to('cpu')
156
  torch.cuda.empty_cache()
 
166
  export_to_video(output, video_out_file, fps=24)
167
  return video_out_file, seed
168
  else:
169
+ segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
170
+ prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
171
+ pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
172
+ prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
173
+ image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
174
+ guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
175
  for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
176
  latents = latents.to(transformer_dtype)
177
  latent_model_input = torch.cat([latents] * 2)