Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -151,6 +151,10 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
|
|
151 |
size = state["width"]
|
152 |
pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
|
153 |
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
|
|
|
|
|
|
|
|
|
154 |
if segment==9:
|
155 |
pipe.transformer.to('cpu')
|
156 |
torch.cuda.empty_cache()
|
@@ -167,10 +171,6 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
|
|
167 |
return video_out_file, seed
|
168 |
else:
|
169 |
segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
|
170 |
-
prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
|
171 |
-
pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
|
172 |
-
prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
|
173 |
-
image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
|
174 |
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
175 |
for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
|
176 |
latents = latents.to(transformer_dtype)
|
|
|
151 |
size = state["width"]
|
152 |
pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
|
153 |
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
|
154 |
+
prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
|
155 |
+
pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
|
156 |
+
prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
|
157 |
+
image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
|
158 |
if segment==9:
|
159 |
pipe.transformer.to('cpu')
|
160 |
torch.cuda.empty_cache()
|
|
|
171 |
return video_out_file, seed
|
172 |
else:
|
173 |
segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
|
|
|
|
|
|
|
|
|
174 |
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
|
175 |
for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
|
176 |
latents = latents.to(transformer_dtype)
|