CosyVoice committed on
Commit
1ab3186
1 Parent(s): 1d881df

revert trt TODO

Browse files
cosyvoice/cli/cosyvoice.py CHANGED
@@ -21,7 +21,7 @@ from cosyvoice.utils.file_utils import logging
21
 
22
  class CosyVoice:
23
 
24
- def __init__(self, model_dir, load_jit=True, load_trt=True):
25
  instruct = True if '-Instruct' in model_dir else False
26
  self.model_dir = model_dir
27
  if not os.path.exists(model_dir):
@@ -42,9 +42,6 @@ class CosyVoice:
42
  if load_jit:
43
  self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
44
  '{}/llm.llm.fp16.zip'.format(model_dir))
45
- if load_trt:
46
- # TODO
47
- self.model.load_trt()
48
  del configs
49
 
50
  def list_avaliable_spks(self):
 
21
 
22
  class CosyVoice:
23
 
24
+ def __init__(self, model_dir, load_jit=True):
25
  instruct = True if '-Instruct' in model_dir else False
26
  self.model_dir = model_dir
27
  if not os.path.exists(model_dir):
 
42
  if load_jit:
43
  self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
44
  '{}/llm.llm.fp16.zip'.format(model_dir))
 
 
 
45
  del configs
46
 
47
  def list_avaliable_spks(self):
cosyvoice/cli/model.py CHANGED
@@ -66,11 +66,6 @@ class CosyVoiceModel:
66
  llm_llm = torch.jit.load(llm_llm_model)
67
  self.llm.llm = llm_llm
68
 
69
- def load_trt(self):
70
- # TODO 你需要的TRT推理的准备
71
- self.flow.decoder.estimator = xxx
72
- self.flow.decoder.session = xxx
73
-
74
  def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
75
  with self.llm_context:
76
  for i in self.llm.inference(text=text.to(self.device),
@@ -126,7 +121,6 @@ class CosyVoiceModel:
126
  self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
127
  p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
128
  p.start()
129
- p.join()
130
  if stream is True:
131
  token_hop_len = self.token_min_hop_len
132
  while True:
@@ -147,7 +141,7 @@ class CosyVoiceModel:
147
  token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
148
  if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
149
  break
150
- # p.join()
151
  # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
152
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
153
  with self.flow_hift_context:
@@ -160,7 +154,7 @@ class CosyVoiceModel:
160
  yield {'tts_speech': this_tts_speech.cpu()}
161
  else:
162
  # deal with all tokens
163
- # p.join()
164
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
165
  with self.flow_hift_context:
166
  this_tts_speech = self.token2wav(token=this_tts_speech_token,
 
66
  llm_llm = torch.jit.load(llm_llm_model)
67
  self.llm.llm = llm_llm
68
 
 
 
 
 
 
69
  def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
70
  with self.llm_context:
71
  for i in self.llm.inference(text=text.to(self.device),
 
121
  self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
122
  p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
123
  p.start()
 
124
  if stream is True:
125
  token_hop_len = self.token_min_hop_len
126
  while True:
 
141
  token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
142
  if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
143
  break
144
+ p.join()
145
  # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
146
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
147
  with self.flow_hift_context:
 
154
  yield {'tts_speech': this_tts_speech.cpu()}
155
  else:
156
  # deal with all tokens
157
+ p.join()
158
  this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
159
  with self.flow_hift_context:
160
  this_tts_speech = self.token2wav(token=this_tts_speech_token,
cosyvoice/flow/flow_matching.py CHANGED
@@ -77,10 +77,10 @@ class ConditionalCFM(BASECFM):
77
  sol = []
78
 
79
  for step in range(1, len(t_span)):
80
- dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
81
  # Classifier-Free Guidance inference introduced in VoiceBox
82
  if self.inference_cfg_rate > 0:
83
- cfg_dphi_dt = self.forward_estimator(
84
  x, mask,
85
  torch.zeros_like(mu), t,
86
  torch.zeros_like(spks) if spks is not None else None,
@@ -96,14 +96,6 @@ class ConditionalCFM(BASECFM):
96
 
97
  return sol[-1]
98
 
99
- # TODO
100
- def forward_estimator(self):
101
- if isinstance(self.estimator, trt):
102
- assert self.training is False, 'tensorrt cannot be used in training'
103
- return xxx
104
- else:
105
- return self.estimator.forward
106
-
107
  def compute_loss(self, x1, mask, mu, spks=None, cond=None):
108
  """Computes diffusion loss
109
 
 
77
  sol = []
78
 
79
  for step in range(1, len(t_span)):
80
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
81
  # Classifier-Free Guidance inference introduced in VoiceBox
82
  if self.inference_cfg_rate > 0:
83
+ cfg_dphi_dt = self.estimator(
84
  x, mask,
85
  torch.zeros_like(mu), t,
86
  torch.zeros_like(spks) if spks is not None else None,
 
96
 
97
  return sol[-1]
98
 
 
 
 
 
 
 
 
 
99
  def compute_loss(self, x1, mask, mu, spks=None, cond=None):
100
  """Computes diffusion loss
101