seokochin committed on
Commit
33df378
·
verified ·
1 Parent(s): 315ca65

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +17 -17
generate.py CHANGED
import argparse
import subprocess
import os

import torch

# CLI arguments forwarded to run_model.py (WAN 2.1 text-to-video inference).
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="t2v-1.3B")
parser.add_argument("--size", type=str, default="832*480")
parser.add_argument("--frame_num", type=int, default=60)
parser.add_argument("--sample_steps", type=int, default=30)
parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-1.3B")
parser.add_argument("--prompt", type=str, required=True)
args = parser.parse_args()

# Check GPU Availability — run_model.py receives the choice via --device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Run WAN 2.1 Inference.
# Build the command as an argument list with shell=False so the
# user-supplied --prompt cannot be interpreted by the shell
# (the original f-string + shell=True was command-injectable).
command = [
    "python", "run_model.py",
    "--task", args.task,
    "--size", args.size,
    "--frame_num", str(args.frame_num),
    "--sample_steps", str(args.sample_steps),
    "--ckpt_dir", args.ckpt_dir,
    "--prompt", args.prompt,
    "--device", device,
]

# BUG FIX: the original executed the command TWICE — once via
# subprocess.run(...) and again via subprocess.Popen(...) — generating the
# video twice. Run it once and capture both output streams.
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()

print("Output:", stdout.decode())
print("Error:", stderr.decode())

# Save output
if os.path.exists("output.mp4"):
    print("✅ Video generated successfully: output.mp4")
else:
    print("❌ Error generating video.")
 
import argparse
import subprocess
import os

from huggingface_hub import snapshot_download

# Arguments — forwarded verbatim to run_model.py (WAN 2.1 T2V inference).
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="t2v-14B")
parser.add_argument("--size", type=str, default="832*480")
parser.add_argument("--frame_num", type=int, default=60)
parser.add_argument("--sample_steps", type=int, default=25)
parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-14B")
# NOTE(review): run_model.py appears to expect these as strings ("True",
# "bf16") — kept as str to preserve the existing CLI contract.
parser.add_argument("--offload_model", type=str, default="True")
parser.add_argument("--precision", type=str, default="bf16")
parser.add_argument("--prompt", type=str, required=True)
args = parser.parse_args()

# Download model if not available
if not os.path.exists(args.ckpt_dir):
    print("🔄 Downloading WAN 2.1 model...")
    snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir=args.ckpt_dir)

# Run Model.
# SECURITY FIX: build the command as an argument list with shell=False so
# the user-supplied --prompt (and paths) cannot be interpreted by the shell;
# the original f-string + shell=True was command-injectable.
command = [
    "python", "run_model.py",
    "--task", args.task,
    "--size", args.size,
    "--frame_num", str(args.frame_num),
    "--sample_steps", str(args.sample_steps),
    "--ckpt_dir", args.ckpt_dir,
    "--offload_model", args.offload_model,
    "--precision", args.precision,
    "--prompt", args.prompt,
]

process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()

print("🔹 Output:", stdout.decode())
print("🔺 Error:", stderr.decode())

# Check if video was created
if os.path.exists("output.mp4"):
    print("✅ Video generated successfully: output.mp4")
else:
    print("❌ Error: Video file not found!")