Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -40,6 +40,8 @@ os.putenv("TOKENIZERS_PARALLELISM","False")
|
|
40 |
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
|
41 |
base_model_id = "hunyuanvideo-community/HunyuanVideo"
|
42 |
|
|
|
|
|
43 |
def init_predictor():
|
44 |
global pipe
|
45 |
text_encoder = LlamaModel.from_pretrained(
|
@@ -59,7 +61,7 @@ def init_predictor():
|
|
59 |
text_encoder=text_encoder,
|
60 |
torch_dtype=torch.bfloat16,
|
61 |
).to("cpu")
|
62 |
-
pipe.to(
|
63 |
|
64 |
@spaces.GPU(duration=60)
|
65 |
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
|
@@ -69,7 +71,7 @@ def generate(segment, image, prompt, size, guidance_scale, num_inference_steps,
|
|
69 |
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
|
70 |
prompt=prompt, prompt_2=prompt, device=device
|
71 |
)
|
72 |
-
pipe.scheduler.set_timesteps(num_inference_steps, device=
|
73 |
timesteps = pipe.scheduler.timesteps
|
74 |
all_timesteps_cpu = timesteps.cpu()
|
75 |
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
|
|
|
40 |
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
|
41 |
base_model_id = "hunyuanvideo-community/HunyuanVideo"
|
42 |
|
43 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
44 |
+
|
45 |
def init_predictor():
|
46 |
global pipe
|
47 |
text_encoder = LlamaModel.from_pretrained(
|
|
|
61 |
text_encoder=text_encoder,
|
62 |
torch_dtype=torch.bfloat16,
|
63 |
).to("cpu")
|
64 |
+
pipe.to(device)
|
65 |
|
66 |
@spaces.GPU(duration=60)
|
67 |
def generate(segment, image, prompt, size, guidance_scale, num_inference_steps, frames, seed, progress=gr.Progress(track_tqdm=True) ):
|
|
|
71 |
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
|
72 |
prompt=prompt, prompt_2=prompt, device=device
|
73 |
)
|
74 |
+
pipe.scheduler.set_timesteps(num_inference_steps, device=device)
|
75 |
timesteps = pipe.scheduler.timesteps
|
76 |
all_timesteps_cpu = timesteps.cpu()
|
77 |
timesteps_split_np = np.array_split(all_timesteps_cpu.numpy(), 8)
|