1inkusFace commited on
Commit
1292897
·
verified ·
1 Parent(s): 6502c3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -40,6 +40,8 @@ os.putenv("TOKENIZERS_PARALLELISM","False")
40
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
41
  base_model_id = "hunyuanvideo-community/HunyuanVideo"
42
 
 
 
43
  def init_predictor():
44
  global pipe
45
  text_encoder = LlamaModel.from_pretrained(
@@ -59,7 +61,7 @@ def init_predictor():
59
  text_encoder=text_encoder,
60
  torch_dtype=torch.bfloat16,
61
  ).to("cpu")
62
- pipe.to(torch.device('cuda'))
63
 
64
  @spaces.GPU(duration=60)
65
  def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
@@ -69,7 +71,7 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
69
  prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
70
  prompt=prompt, prompt_2=prompt, device=device
71
  )
72
- pipe.scheduler.set_timesteps(num_inference_steps, device=torch.device('cuda'))
73
  timesteps = pipe.scheduler.timesteps
74
  all_timesteps_cpu = timesteps.cpu()
75
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
 
40
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
41
  base_model_id = "hunyuanvideo-community/HunyuanVideo"
42
 
43
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
+
45
  def init_predictor():
46
  global pipe
47
  text_encoder = LlamaModel.from_pretrained(
 
61
  text_encoder=text_encoder,
62
  torch_dtype=torch.bfloat16,
63
  ).to("cpu")
64
+ pipe.to(device)
65
 
66
  @spaces.GPU(duration=60)
67
  def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
 
71
  prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
72
  prompt=prompt, prompt_2=prompt, device=device
73
  )
74
+ pipe.scheduler.set_timesteps(num_inference_steps, device=device)
75
  timesteps = pipe.scheduler.timesteps
76
  all_timesteps_cpu = timesteps.cpu()
77
  timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)