1inkusFace committed · Commit 016261b · verified · 1 Parent(s): 217d984

Update app.py

Files changed (1):
  app.py +16 -17
app.py CHANGED
@@ -138,34 +138,21 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
         torch.cuda.empty_cache()
         guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
     else:
-        pipe.vae.to("cpu")
-        torch.cuda.empty_cache()
-        transformer_dtype = pipe.transformer.dtype
         state_file = f"SkyReel_{segment-1}_{seed}.pt"
+        transformer_dtype = pipe.transformer.dtype
         state = torch.load(state_file, weights_only=False)
         generator = torch.Generator(device='cuda').manual_seed(seed)
         latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
-        guidance_scale = state["guidance_scale"]
         all_timesteps_cpu = state["all_timesteps"]
-        size = state["height"]
-        size = state["width"]
         pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
-        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
-        segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
-        prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-        pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-        prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
-        image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
-        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
-        #pipe.transformer.to("cuda")
         if segment==9:
             pipe.transformer.to('cpu')
             torch.cuda.empty_cache()
             pipe.vae.to("cuda")
             latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
-            #with torch.no_grad():
-            video = pipe.vae.decode(latents, return_dict=False)[0]
-            video = pipe.video_processor.postprocess_video(video)
+            with torch.no_grad():
+                video = pipe.vae.decode(latents, return_dict=False)[0]
+            video = pipe.video_processor.postprocess_video(video)
             # return HunyuanVideoPipelineOutput(frames=video)
             save_dir = f"./"
             video_out_file = f"{save_dir}/{seed}.mp4"
@@ -173,6 +160,18 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
             export_to_video(output, video_out_file, fps=24)
             return video_out_file, seed
         else:
+            pipe.vae.to("cpu")
+            torch.cuda.empty_cache()
+            guidance_scale = state["guidance_scale"]
+            size = state["height"]
+            size = state["width"]
+            timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
+            segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
+            prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
+            pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
+            prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
+            image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
+            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
         for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
             latents = latents.to(transformer_dtype)
             latent_model_input = torch.cat([latents] * 2)
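
Note on the pattern this diff rearranges: app.py runs the denoising schedule in pieces, checkpointing intermediate latents to a SkyReel_{segment}_{seed}.pt file after each segment and reloading them on the next call. The key arithmetic is the np.array_split over the saved timestep schedule. A minimal, self-contained sketch of that slicing (NUM_SEGMENTS and segment_slice are illustrative names, not the app's actual API):

import numpy as np
import torch

NUM_SEGMENTS = 8  # app.py splits the schedule into 8 segments

def segment_slice(all_timesteps_cpu: torch.Tensor, segment: int) -> torch.Tensor:
    # Mirror the np.array_split call in app.py: segment is 1-based,
    # and the pieces differ in length by at most one step.
    splits = np.array_split(all_timesteps_cpu.numpy(), NUM_SEGMENTS)
    return torch.from_numpy(splits[segment - 1])

# Example: a 50-step schedule yields segments of 7 or 6 steps that
# together cover every timestep exactly once.
schedule = torch.arange(50)
pieces = [segment_slice(schedule, s) for s in range(1, NUM_SEGMENTS + 1)]
assert sum(len(p) for p in pieces) == len(schedule)

The segment==9 branch is the decode-only pass: the transformer is moved to CPU and the cache is emptied to free VRAM before the VAE is moved onto the GPU, and the decode now runs under torch.no_grad() so no autograd graph is kept alive while decoding the final latents to video frames.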