Xinsheng Wang commited on
Commit
882c25f
·
unverified ·
2 Parent(s): 714c642 72f27d6

Merge pull request #54 from xcv58/macos

Browse files

Add support for macOS Metal Performance Shaders (MPS) and CPU fallback

Files changed (2) hide show
  1. cli/inference.py +13 -1
  2. webui.py +18 -2
cli/inference.py CHANGED
@@ -20,6 +20,7 @@ import torch
20
  import soundfile as sf
21
  import logging
22
  from datetime import datetime
 
23
 
24
  from cli.SparkTTS import SparkTTS
25
 
@@ -69,7 +70,18 @@ def run_tts(args):
69
  os.makedirs(args.save_dir, exist_ok=True)
70
 
71
  # Convert device argument to torch.device
72
- device = torch.device(f"cuda:{args.device}")
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Initialize the model
75
  model = SparkTTS(args.model_dir, device)
 
20
  import soundfile as sf
21
  import logging
22
  from datetime import datetime
23
+ import platform
24
 
25
  from cli.SparkTTS import SparkTTS
26
 
 
70
  os.makedirs(args.save_dir, exist_ok=True)
71
 
72
  # Convert device argument to torch.device
73
+ if platform.system() == "Darwin" and torch.backends.mps.is_available():
74
+ # macOS with MPS support (Apple Silicon)
75
+ device = torch.device(f"mps:{args.device}")
76
+ logging.info(f"Using MPS device: {device}")
77
+ elif torch.cuda.is_available():
78
+ # System with CUDA support
79
+ device = torch.device(f"cuda:{args.device}")
80
+ logging.info(f"Using CUDA device: {device}")
81
+ else:
82
+ # Fall back to CPU
83
+ device = torch.device("cpu")
84
+ logging.info("GPU acceleration not available, using CPU")
85
 
86
  # Initialize the model
87
  model = SparkTTS(args.model_dir, device)
webui.py CHANGED
@@ -19,6 +19,8 @@ import soundfile as sf
19
  import logging
20
  import argparse
21
  import gradio as gr
 
 
22
  from datetime import datetime
23
  from cli.SparkTTS import SparkTTS
24
  from sparktts.utils.token_parser import LEVELS_MAP_UI
@@ -27,7 +29,21 @@ from sparktts.utils.token_parser import LEVELS_MAP_UI
27
  def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
28
  """Load the model once at the beginning."""
29
  logging.info(f"Loading model from: {model_dir}")
30
- device = torch.device(f"cuda:{device}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  model = SparkTTS(model_dir, device)
32
  return model
33
 
@@ -76,7 +92,7 @@ def run_tts(
76
 
77
 
78
  def build_ui(model_dir, device=0):
79
-
80
  # Initialize model
81
  model = initialize_model(model_dir, device=device)
82
 
 
19
  import logging
20
  import argparse
21
  import gradio as gr
22
+ import platform
23
+
24
  from datetime import datetime
25
  from cli.SparkTTS import SparkTTS
26
  from sparktts.utils.token_parser import LEVELS_MAP_UI
 
29
  def initialize_model(model_dir="pretrained_models/Spark-TTS-0.5B", device=0):
30
  """Load the model once at the beginning."""
31
  logging.info(f"Loading model from: {model_dir}")
32
+
33
+ # Determine appropriate device based on platform and availability
34
+ if platform.system() == "Darwin":
35
+ # macOS with MPS support (Apple Silicon)
36
+ device = torch.device(f"mps:{device}")
37
+ logging.info(f"Using MPS device: {device}")
38
+ elif torch.cuda.is_available():
39
+ # System with CUDA support
40
+ device = torch.device(f"cuda:{device}")
41
+ logging.info(f"Using CUDA device: {device}")
42
+ else:
43
+ # Fall back to CPU
44
+ device = torch.device("cpu")
45
+ logging.info("GPU acceleration not available, using CPU")
46
+
47
  model = SparkTTS(model_dir, device)
48
  return model
49
 
 
92
 
93
 
94
  def build_ui(model_dir, device=0):
95
+
96
  # Initialize model
97
  model = initialize_model(model_dir, device=device)
98