liubaiji committed on
Commit
9e0b99e
1 Parent(s): df653f1

[feature] fix bad case, add fade on speech output

Browse files
Files changed (1) hide show
  1. cosyvoice/cli/model.py +9 -1
cosyvoice/cli/model.py CHANGED
@@ -49,6 +49,7 @@ class CosyVoiceModel:
49
  self.llm_end_dict = {}
50
  self.mel_overlap_dict = {}
51
  self.hift_cache_dict = {}
 
52
 
53
  def load(self, llm_model, flow_model, hift_model):
54
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -113,10 +114,17 @@ class CosyVoiceModel:
113
  self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
114
  tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
115
  tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
116
- self.hift_cache_dict[uuid] = {'source': tts_source[:, :, -self.source_cache_len:], 'mel': tts_mel[:, :, -self.mel_cache_len:]}
 
 
 
 
 
117
  tts_speech = tts_speech[:, :-self.source_cache_len]
118
  else:
119
  tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
 
 
120
  return tts_speech
121
 
122
  def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
 
49
  self.llm_end_dict = {}
50
  self.mel_overlap_dict = {}
51
  self.hift_cache_dict = {}
52
+ self.speech_window = np.hamming(2 * self.source_cache_len)
53
 
54
  def load(self, llm_model, flow_model, hift_model):
55
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
 
114
  self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
115
  tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
116
  tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
117
+ if self.hift_cache_dict[uuid] is not None:
118
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
119
+ self.hift_cache_dict[uuid] = {
120
+ 'mel': tts_mel[:, :, -self.mel_cache_len:],
121
+ 'source': tts_source[:, :, -self.source_cache_len:],
122
+ 'speech': tts_speech[:, -self.source_cache_len:]}
123
  tts_speech = tts_speech[:, :-self.source_cache_len]
124
  else:
125
  tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source)
126
+ if self.hift_cache_dict[uuid] is not None:
127
+ tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
128
  return tts_speech
129
 
130
  def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),