1inkusFace committed
Commit dc7a5fd · verified · 1 Parent(s): 016261b

Update app.py

Files changed (1)
  1. app.py (+17, -16)
app.py CHANGED
@@ -138,21 +138,34 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
         torch.cuda.empty_cache()
         guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
     else:
-        state_file = f"SkyReel_{segment-1}_{seed}.pt"
+        pipe.vae.to("cpu")
+        torch.cuda.empty_cache()
         transformer_dtype = pipe.transformer.dtype
+        state_file = f"SkyReel_{segment-1}_{seed}.pt"
         state = torch.load(state_file, weights_only=False)
         generator = torch.Generator(device='cuda').manual_seed(seed)
         latents = state["intermediate_latents"].to("cuda", dtype=torch.bfloat16)
+        guidance_scale = state["guidance_scale"]
         all_timesteps_cpu = state["all_timesteps"]
+        size = state["height"]
+        size = state["width"]
         pipe.scheduler.set_timesteps(len(all_timesteps_cpu), device=device)
+        timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
+        segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
+        prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
+        pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
+        prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
+        image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
+        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+        #pipe.transformer.to("cuda")
         if segment==9:
             pipe.transformer.to('cpu')
             torch.cuda.empty_cache()
             pipe.vae.to("cuda")
             latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
-            with torch.no_grad():
-                video = pipe.vae.decode(latents, return_dict=False)[0]
-            video = pipe.video_processor.postprocess_video(video)
+            #with torch.no_grad():
+            video = pipe.vae.decode(latents, return_dict=False)[0]
+            video = pipe.video_processor.postprocess_video(video)
             # return HunyuanVideoPipelineOutput(frames=video)
             save_dir = f"./"
             video_out_file = f"{save_dir}/{seed}.mp4"
@@ -160,18 +173,6 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
             export_to_video(output, video_out_file, fps=24)
             return video_out_file, seed
         else:
-            pipe.vae.to("cpu")
-            torch.cuda.empty_cache()
-            guidance_scale = state["guidance_scale"]
-            size = state["height"]
-            size = state["width"]
-            timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
-            segment_timesteps = torch.from_numpy(timesteps_split_np[segment - 1]).to("cuda")
-            prompt_embeds = state["prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-            pooled_prompt_embeds = state["pooled_prompt_embeds"].to("cuda", dtype=torch.bfloat16)
-            prompt_attention_mask = state["prompt_attention_mask"].to("cuda", dtype=torch.bfloat16)
-            image_latents = state["image_latents"].to("cuda", dtype=torch.bfloat16)
-            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
             for i, t in enumerate(pipe.progress_bar(segment_timesteps)):
                 latents = latents.to(transformer_dtype)
                 latent_model_input = torch.cat([latents] * 2)
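
For context: the reshuffled else branch resumes a denoising run from a per-segment checkpoint. It loads SkyReel_{segment-1}_{seed}.pt, restores the latents, prompt embeddings, and timestep schedule, and now offloads the VAE to the CPU up front so the transformer has the GPU to itself. A minimal sketch of what the producing side presumably looks like, assuming it mirrors the keys read here (save_segment_state is a hypothetical helper, not part of app.py):

import torch

def save_segment_state(segment, seed, latents, all_timesteps, guidance_scale,
                       height, width, prompt_embeds, pooled_prompt_embeds,
                       prompt_attention_mask, image_latents):
    # Move tensors to CPU so the .pt file can be reloaded on any device.
    state = {
        "intermediate_latents": latents.detach().to("cpu"),
        "all_timesteps": all_timesteps.to("cpu"),
        "guidance_scale": guidance_scale,
        "height": height,
        "width": width,
        "prompt_embeds": prompt_embeds.to("cpu"),
        "pooled_prompt_embeds": pooled_prompt_embeds.to("cpu"),
        "prompt_attention_mask": prompt_attention_mask.to("cpu"),
        "image_latents": image_latents.to("cpu"),
    }
    torch.save(state, f"SkyReel_{segment}_{seed}.pt")  # read back as the next segment's input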
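
For reference, the segment slicing relies on np.array_split, which divides the schedule into 8 nearly equal chunks even when its length is not divisible by 8 (later chunks come up one element short), so segments 1 through 8 map cleanly onto timesteps_split_np[segment - 1]:

import numpy as np

all_timesteps = np.arange(50)            # stand-in for the real schedule
chunks = np.array_split(all_timesteps, 8)
print([len(c) for c in chunks])          # [7, 7, 6, 6, 6, 6, 6, 6]
segment = 3
segment_timesteps = chunks[segment - 1]  # the slice this segment denoises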
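
The repeated .to("cpu") / .to("cuda") pairs in the diff follow one pattern: keep only the component that is about to run on the GPU, and release the freed blocks immediately. A sketch of the idea at decode time, assuming a diffusers-style pipeline object (swap_to_vae is illustrative, not part of app.py):

import torch

def swap_to_vae(pipe):
    pipe.transformer.to("cpu")  # denoising is finished; evict the transformer
    torch.cuda.empty_cache()    # return the freed blocks to the CUDA allocator
    pipe.vae.to("cuda")         # bring in the decoder for the final pass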