xcv58 commited on
Commit
7d742ba
·
unverified ·
1 Parent(s): 714c642

Add support for macOS Metal Performance Shaders (MPS) and CPU fallback

Browse files
Files changed (1) hide show
  1. cli/inference.py +13 -1
cli/inference.py CHANGED
@@ -20,6 +20,7 @@ import torch
20
  import soundfile as sf
21
  import logging
22
  from datetime import datetime
 
23
 
24
  from cli.SparkTTS import SparkTTS
25
 
@@ -69,7 +70,18 @@ def run_tts(args):
69
  os.makedirs(args.save_dir, exist_ok=True)
70
 
71
  # Convert device argument to torch.device
72
- device = torch.device(f"cuda:{args.device}")
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Initialize the model
75
  model = SparkTTS(args.model_dir, device)
 
20
  import soundfile as sf
21
  import logging
22
  from datetime import datetime
23
+ import platform
24
 
25
  from cli.SparkTTS import SparkTTS
26
 
 
70
  os.makedirs(args.save_dir, exist_ok=True)
71
 
72
  # Convert device argument to torch.device
73
+ if platform.system() == "Darwin" and torch.backends.mps.is_available():
74
+ # macOS with MPS support (Apple Silicon)
75
+ device = torch.device(f"mps:{args.device}")
76
+ logging.info(f"Using MPS device: {device}")
77
+ elif torch.cuda.is_available():
78
+ # System with CUDA support
79
+ device = torch.device(f"cuda:{args.device}")
80
+ logging.info(f"Using CUDA device: {device}")
81
+ else:
82
+ # Fall back to CPU
83
+ device = torch.device("cpu")
84
+ logging.info("GPU acceleration not available, using CPU")
85
 
86
  # Initialize the model
87
  model = SparkTTS(args.model_dir, device)