chenjgtea committed
Commit c23cee5 · 1 Parent(s): d8d4be7

torch update

Files changed (3):
  1. tool/__init__.py +2 -1
  2. tool/gpu.py +48 -0
  3. web/app.py +5 -1
tool/__init__.py CHANGED
@@ -1,4 +1,5 @@
 from .av import load_audio
 from .pcm import pcm_arr_to_mp3_view
 from .np import float_to_int16
-from .ctx import TorchSeedContext
+from .ctx import TorchSeedContext
+from .gpu import select_device
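
With the re-export in place, downstream code can import the helper from the package root as well as from the submodule; a minimal sketch (hypothetical caller, not part of this commit):

    from tool import select_device  # equivalent to: from tool.gpu import select_device

    device = select_device()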
tool/gpu.py ADDED
@@ -0,0 +1,48 @@
+import torch
+import os, sys
+import spaces
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+from tool.logger import get_logger
+
+logger = get_logger("gpu")
+
+
+def select_device(min_memory=2047, experimental=False):
+    if torch.cuda.is_available():
+        selected_gpu = 0
+        max_free_memory = -1
+        for i in range(torch.cuda.device_count()):
+            props = torch.cuda.get_device_properties(i)
+            free_memory = props.total_memory - torch.cuda.memory_reserved(i)
+            if max_free_memory < free_memory:
+                selected_gpu = i
+                max_free_memory = free_memory
+        free_memory_mb = max_free_memory / (1024 * 1024)
+        if free_memory_mb < min_memory:
+            logger.warning(
+                f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
+            )
+            device = torch.device("cpu")
+        else:
+            device = torch.device(f"cuda:{selected_gpu}")
+    elif torch.backends.mps.is_available():
+        """
+        MPS is currently slower than CPU while needing more memory and core
+        utilization, so only enable this for experimental use.
+        """
+        if experimental:
+            # For Apple M1/M2 chips with Metal Performance Shaders
+            logger.warning("experimental: found Apple GPU, using MPS.")
+            device = torch.device("mps")
+        else:
+            logger.info("found Apple GPU, but using CPU.")
+            device = torch.device("cpu")
+    else:
+        logger.warning("no GPU found, using CPU instead")
+        device = torch.device("cpu")
+
+    return device
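
select_device picks the CUDA device with the most free memory, estimated as total_memory minus memory_reserved, falls back to CPU when fewer than min_memory MB are free, and uses Apple's MPS backend only when experimental=True. Note the estimate only reflects this process's caching allocator, not memory held by other processes. A minimal usage sketch (hypothetical caller, not part of this commit):

    import torch
    from tool.gpu import select_device

    # min_memory is in MB; below ~2 GB free the helper falls back to CPU
    device = select_device(min_memory=2047, experimental=False)
    x = torch.zeros(8, device=device)  # tensors and models then follow the selected device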
web/app.py CHANGED
@@ -9,6 +9,7 @@ sys.path.append(now_dir)
 from tool.logger import get_logger
 from tool.func import *
 from tool.np import *
+from tool.gpu import select_device
 import ChatTTS
 import argparse
 import torch._dynamo
@@ -33,8 +34,11 @@ def init_chat(args):
     if MODEL == "HF":
         source = "huggingface"
 
+    device = select_device()
 
+    logger.info("loading ChatTTS device: " + str(device))
+
-    if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts", device=torch.device(f"cuda:0")):
+    if chat.load(source=source, custom_path="D:\\chenjgspace\\ai-model\\chattts", device=device):
         print("Models loaded successfully.")
         logger.info("Models loaded successfully.")
     else:
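
The net effect in web/app.py is that the device is resolved at runtime instead of being hard-coded to cuda:0, so the app can also start on CPU-only or Apple hardware. A condensed sketch of the new flow, assuming ChatTTS.Chat.load accepts the source, custom_path, and device keywords used above (custom_path shown as a placeholder, not the author's local path):

    import ChatTTS
    from tool.gpu import select_device

    chat = ChatTTS.Chat()
    device = select_device()  # cuda:<n>, mps, or cpu
    if chat.load(source="huggingface", custom_path="/path/to/chattts", device=device):
        print("Models loaded successfully.")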