Flux9665 commited on
Commit
4f19891
·
verified ·
1 Parent(s): 3abac7b

Update InferenceInterfaces/ToucanTTSInterface.py

Browse files
InferenceInterfaces/ToucanTTSInterface.py CHANGED
@@ -24,15 +24,19 @@ class ToucanTTSInterface(torch.nn.Module):
24
 
25
  def __init__(self,
26
  device="cpu", # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
27
- tts_model_path=os.path.join(MODELS_DIR, f"ToucanTTS_Meta", "best.pt"), # path to the ToucanTTS checkpoint or just a shorthand if run standalone
28
- vocoder_model_path=os.path.join(MODELS_DIR, f"Vocoder", "best.pt"), # path to the Vocoder checkpoint
29
  language="eng", # initial language of the model, can be changed later with the setter methods
30
  ):
31
  super().__init__()
32
  self.device = device
33
- if not tts_model_path.endswith(".pt"):
 
 
34
  # default to shorthand system
35
  tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")
 
 
36
 
37
  ################################
38
  # build text to phone #
 
24
 
25
  def __init__(self,
26
  device="cpu", # device that everything computes on. If a cuda device is available, this can speed things up by an order of magnitude.
27
+ tts_model_path=None, # path to the ToucanTTS checkpoint or just a shorthand if run standalone
28
+ vocoder_model_path=None, # path to the Vocoder checkpoint
29
  language="eng", # initial language of the model, can be changed later with the setter methods
30
  ):
31
  super().__init__()
32
  self.device = device
33
+ if tts_model_path is None:
34
+ tts_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="ToucanTTS.pt")
35
+ elif not tts_model_path.endswith(".pt"):
36
  # default to shorthand system
37
  tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")
38
+ if vocoder_model_path is None:
39
+ vocoder_model_path = hf_hub_download(repo_id="Flux9665/ToucanTTS", filename="Vocoder.pt")
40
 
41
  ################################
42
  # build text to phone #