CosyVoice commited on
Commit
f4e70e2
1 Parent(s): 02f941d

update stream code

Browse files
.gitignore CHANGED
@@ -43,6 +43,8 @@ compile_commands.json
43
 
44
  # train/inference files
45
  *.wav
 
 
46
  *.pt
47
  pretrained_models/*
48
  *_pb2_grpc.py
 
43
 
44
  # train/inference files
45
  *.wav
46
+ *.m4a
47
+ *.aac
48
  *.pt
49
  pretrained_models/*
50
  *_pb2_grpc.py
README.md CHANGED
@@ -86,23 +86,24 @@ import torchaudio
86
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
87
  # sft usage
88
  print(cosyvoice.list_avaliable_spks())
89
- output = cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女')
90
- torchaudio.save('sft.wav', output['tts_speech'], 22050)
 
91
 
92
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
93
  # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
94
  prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
95
- output = cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k)
96
- torchaudio.save('zero_shot.wav', output['tts_speech'], 22050)
97
  # cross_lingual usage
98
  prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
99
- output = cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k)
100
- torchaudio.save('cross_lingual.wav', output['tts_speech'], 22050)
101
 
102
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
103
  # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
104
- output = cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
105
- torchaudio.save('instruct.wav', output['tts_speech'], 22050)
106
  ```
107
 
108
  **Start web demo**
@@ -133,10 +134,10 @@ docker build -t cosyvoice:v1.0 .
133
  # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
134
  # for grpc usage
135
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
136
- python3 grpc/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
137
  # for fastapi usage
138
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
139
- python3 fastapi/client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
140
  ```
141
 
142
  ## Discussion & Communication
 
86
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
87
  # sft usage
88
  print(cosyvoice.list_avaliable_spks())
89
+ # change stream=True for chunk stream inference
90
+ for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
91
+ torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)
92
 
93
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
94
  # zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
95
  prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
96
+ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
97
+ torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], 22050)
98
  # cross_lingual usage
99
  prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
100
+ for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
101
+ torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
102
 
103
  cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
104
  # instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
105
+ for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男', 'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.', stream=False)):
106
+ torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], 22050)
107
  ```
108
 
109
  **Start web demo**
 
134
  # change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
135
  # for grpc usage
136
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
137
+ cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
138
  # for fastapi usage
139
  docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
140
+ cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
141
  ```
142
 
143
  ## Discussion & Communication
cosyvoice/bin/inference.py CHANGED
@@ -100,10 +100,13 @@ def main():
100
  'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
  'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
  'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
- model_output = model.inference(**model_input)
 
 
 
104
  tts_key = '{}_{}'.format(utts[0], tts_index[0])
105
  tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
106
- torchaudio.save(tts_fn, model_output['tts_speech'], sample_rate=22050)
107
  f.write('{} {}\n'.format(tts_key, tts_fn))
108
  f.flush()
109
  f.close()
 
100
  'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
101
  'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
102
  'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
103
+ tts_speeches = []
104
+ for model_output in model.inference(**model_input):
105
+ tts_speeches.append(model_output['tts_speech'])
106
+ tts_speeches = torch.concat(tts_speeches, dim=1)
107
  tts_key = '{}_{}'.format(utts[0], tts_index[0])
108
  tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
109
+ torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
110
  f.write('{} {}\n'.format(tts_key, tts_fn))
111
  f.flush()
112
  f.close()
cosyvoice/cli/cosyvoice.py CHANGED
@@ -49,6 +49,7 @@ class CosyVoice:
49
  for i in self.frontend.text_normalize(tts_text, split=True):
50
  model_input = self.frontend.frontend_sft(i, spk_id)
51
  start_time = time.time()
 
52
  for model_output in self.model.inference(**model_input, stream=stream):
53
  speech_len = model_output['tts_speech'].shape[1] / 22050
54
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -60,6 +61,7 @@ class CosyVoice:
60
  for i in self.frontend.text_normalize(tts_text, split=True):
61
  model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
62
  start_time = time.time()
 
63
  for model_output in self.model.inference(**model_input, stream=stream):
64
  speech_len = model_output['tts_speech'].shape[1] / 22050
65
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -72,6 +74,7 @@ class CosyVoice:
72
  for i in self.frontend.text_normalize(tts_text, split=True):
73
  model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
74
  start_time = time.time()
 
75
  for model_output in self.model.inference(**model_input, stream=stream):
76
  speech_len = model_output['tts_speech'].shape[1] / 22050
77
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
@@ -85,6 +88,7 @@ class CosyVoice:
85
  for i in self.frontend.text_normalize(tts_text, split=True):
86
  model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
87
  start_time = time.time()
 
88
  for model_output in self.model.inference(**model_input, stream=stream):
89
  speech_len = model_output['tts_speech'].shape[1] / 22050
90
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
 
49
  for i in self.frontend.text_normalize(tts_text, split=True):
50
  model_input = self.frontend.frontend_sft(i, spk_id)
51
  start_time = time.time()
52
+ logging.info('synthesis text {}'.format(i))
53
  for model_output in self.model.inference(**model_input, stream=stream):
54
  speech_len = model_output['tts_speech'].shape[1] / 22050
55
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
 
61
  for i in self.frontend.text_normalize(tts_text, split=True):
62
  model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
63
  start_time = time.time()
64
+ logging.info('synthesis text {}'.format(i))
65
  for model_output in self.model.inference(**model_input, stream=stream):
66
  speech_len = model_output['tts_speech'].shape[1] / 22050
67
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
 
74
  for i in self.frontend.text_normalize(tts_text, split=True):
75
  model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
76
  start_time = time.time()
77
+ logging.info('synthesis text {}'.format(i))
78
  for model_output in self.model.inference(**model_input, stream=stream):
79
  speech_len = model_output['tts_speech'].shape[1] / 22050
80
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
 
88
  for i in self.frontend.text_normalize(tts_text, split=True):
89
  model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
90
  start_time = time.time()
91
+ logging.info('synthesis text {}'.format(i))
92
  for model_output in self.model.inference(**model_input, stream=stream):
93
  speech_len = model_output['tts_speech'].shape[1] / 22050
94
  logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
cosyvoice/cli/model.py CHANGED
@@ -16,6 +16,8 @@ import numpy as np
16
  import threading
17
  import time
18
  from contextlib import nullcontext
 
 
19
 
20
 
21
  class CosyVoiceModel:
@@ -28,13 +30,19 @@ class CosyVoiceModel:
28
  self.llm = llm
29
  self.flow = flow
30
  self.hift = hift
31
- self.stream_win_len = 60 * 4
32
- self.stream_hop_len = 50 * 4
33
- self.overlap = 4395 * 4 # 10 token equals 4395 sample point
34
- self.window = np.hamming(2 * self.overlap)
 
 
 
35
  self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
36
  self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
37
  self.lock = threading.Lock()
 
 
 
38
 
39
  def load(self, llm_model, flow_model, hift_model):
40
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -44,7 +52,7 @@ class CosyVoiceModel:
44
  self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
45
  self.hift.to(self.device).eval()
46
 
47
- def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding):
48
  with self.llm_context:
49
  for i in self.llm.inference(text=text.to(self.device),
50
  text_len=text_len.to(self.device),
@@ -53,13 +61,11 @@ class CosyVoiceModel:
53
  prompt_speech_token=llm_prompt_speech_token.to(self.device),
54
  prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
55
  embedding=llm_embedding.to(self.device),
56
- beam_size=1,
57
  sampling=25,
58
  max_token_text_ratio=30,
59
- min_token_text_ratio=3,
60
- stream=True):
61
- self.tts_speech_token.append(i)
62
- self.llm_end = True
63
 
64
  def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
65
  with self.flow_hift_context:
@@ -78,15 +84,19 @@ class CosyVoiceModel:
78
  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
79
  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
80
  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
 
 
 
 
 
 
 
81
  if stream is True:
82
- self.tts_speech_token, self.llm_end, cache_speech = [], False, None
83
- p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
84
- llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device)))
85
- p.start()
86
  while True:
87
  time.sleep(0.1)
88
- if len(self.tts_speech_token) >= self.stream_win_len:
89
- this_tts_speech_token = torch.concat(self.tts_speech_token[:self.stream_win_len], dim=1)
90
  with self.flow_hift_context:
91
  this_tts_speech = self.token2wav(token=this_tts_speech_token,
92
  prompt_token=flow_prompt_speech_token.to(self.device),
@@ -96,57 +106,48 @@ class CosyVoiceModel:
96
  embedding=flow_embedding.to(self.device))
97
  # fade in/out if necessary
98
  if cache_speech is not None:
99
- this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
100
- yield {'tts_speech': this_tts_speech[:, :-self.overlap]}
101
- cache_speech = this_tts_speech[:, -self.overlap:]
 
102
  with self.lock:
103
- self.tts_speech_token = self.tts_speech_token[self.stream_hop_len:]
104
- if self.llm_end is True:
 
 
105
  break
106
- # deal with remain tokens
107
- if cache_speech is None or len(self.tts_speech_token) > self.stream_win_len - self.stream_hop_len:
108
- this_tts_speech_token = torch.concat(self.tts_speech_token, dim=1)
109
- with self.flow_hift_context:
110
- this_tts_mel = self.flow.inference(token=this_tts_speech_token,
111
- token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
112
- prompt_token=flow_prompt_speech_token.to(self.device),
113
- prompt_token_len=flow_prompt_speech_token_len.to(self.device),
114
- prompt_feat=prompt_speech_feat.to(self.device),
115
- prompt_feat_len=prompt_speech_feat_len.to(self.device),
116
- embedding=flow_embedding.to(self.device))
117
- this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
118
- if cache_speech is not None:
119
- this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
120
- yield {'tts_speech': this_tts_speech}
121
- else:
122
- assert len(self.tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
123
- yield {'tts_speech': cache_speech}
124
  p.join()
125
- torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  else:
127
- tts_speech_token = []
128
- for i in self.llm.inference(text=text.to(self.device),
129
- text_len=text_len.to(self.device),
130
- prompt_text=prompt_text.to(self.device),
131
- prompt_text_len=prompt_text_len.to(self.device),
132
- prompt_speech_token=llm_prompt_speech_token.to(self.device),
133
- prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
134
- embedding=llm_embedding.to(self.device),
135
- beam_size=1,
136
- sampling=25,
137
- max_token_text_ratio=30,
138
- min_token_text_ratio=3,
139
- stream=stream):
140
- tts_speech_token.append(i)
141
- assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
142
- tts_speech_token = torch.concat(tts_speech_token, dim=1)
143
- tts_mel = self.flow.inference(token=tts_speech_token,
144
- token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
145
- prompt_token=flow_prompt_speech_token.to(self.device),
146
- prompt_token_len=flow_prompt_speech_token_len.to(self.device),
147
- prompt_feat=prompt_speech_feat.to(self.device),
148
- prompt_feat_len=prompt_speech_feat_len.to(self.device),
149
- embedding=flow_embedding.to(self.device))
150
- tts_speech = self.hift.inference(mel=tts_mel).cpu()
151
- torch.cuda.empty_cache()
152
- yield {'tts_speech': tts_speech}
 
16
  import threading
17
  import time
18
  from contextlib import nullcontext
19
+ import uuid
20
+ from cosyvoice.utils.common import fade_in_out
21
 
22
 
23
  class CosyVoiceModel:
 
30
  self.llm = llm
31
  self.flow = flow
32
  self.hift = hift
33
+ self.token_min_hop_len = 100
34
+ self.token_max_hop_len = 400
35
+ self.token_overlap_len = 20
36
+ self.speech_overlap_len = 34 * 256
37
+ self.window = np.hamming(2 * self.speech_overlap_len)
38
+ self.stream_scale_factor = 1
39
+ assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
40
  self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
41
  self.flow_hift_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
42
  self.lock = threading.Lock()
43
+ # dict used to store session related variable
44
+ self.tts_speech_token = {}
45
+ self.llm_end = {}
46
 
47
  def load(self, llm_model, flow_model, hift_model):
48
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
 
52
  self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
53
  self.hift.to(self.device).eval()
54
 
55
+ def llm_job(self, text, text_len, prompt_text, prompt_text_len, llm_prompt_speech_token, llm_prompt_speech_token_len, llm_embedding, this_uuid):
56
  with self.llm_context:
57
  for i in self.llm.inference(text=text.to(self.device),
58
  text_len=text_len.to(self.device),
 
61
  prompt_speech_token=llm_prompt_speech_token.to(self.device),
62
  prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
63
  embedding=llm_embedding.to(self.device),
 
64
  sampling=25,
65
  max_token_text_ratio=30,
66
+ min_token_text_ratio=3):
67
+ self.tts_speech_token[this_uuid].append(i)
68
+ self.llm_end[this_uuid] = True
 
69
 
70
  def token2wav(self, token, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, embedding):
71
  with self.flow_hift_context:
 
84
  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
85
  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
86
  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
87
+ # this_uuid is used to track variables related to this inference thread
88
+ this_uuid = str(uuid.uuid1())
89
+ with self.lock:
90
+ self.tts_speech_token[this_uuid], self.llm_end[this_uuid] = [], False
91
+ p = threading.Thread(target=self.llm_job, args=(text.to(self.device), text_len.to(self.device), prompt_text.to(self.device), prompt_text_len.to(self.device),
92
+ llm_prompt_speech_token.to(self.device), llm_prompt_speech_token_len.to(self.device), llm_embedding.to(self.device), this_uuid))
93
+ p.start()
94
  if stream is True:
95
+ cache_speech, cache_token, token_hop_len = None, None, self.token_min_hop_len
 
 
 
96
  while True:
97
  time.sleep(0.1)
98
+ if len(self.tts_speech_token[this_uuid]) >= token_hop_len + self.token_overlap_len:
99
+ this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
100
  with self.flow_hift_context:
101
  this_tts_speech = self.token2wav(token=this_tts_speech_token,
102
  prompt_token=flow_prompt_speech_token.to(self.device),
 
106
  embedding=flow_embedding.to(self.device))
107
  # fade in/out if necessary
108
  if cache_speech is not None:
109
+ this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
110
+ yield {'tts_speech': this_tts_speech[:, :-self.speech_overlap_len]}
111
+ cache_speech = this_tts_speech[:, -self.speech_overlap_len:]
112
+ cache_token = self.tts_speech_token[this_uuid][:token_hop_len]
113
  with self.lock:
114
+ self.tts_speech_token[this_uuid] = self.tts_speech_token[this_uuid][token_hop_len:]
115
+ # increase token_hop_len for better speech quality
116
+ token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
117
+ if self.llm_end[this_uuid] is True and len(self.tts_speech_token[this_uuid]) < token_hop_len + self.token_overlap_len:
118
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  p.join()
120
+ # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
121
+ this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
122
+ if this_tts_speech_token.shape[1] < self.token_min_hop_len + self.token_overlap_len and cache_token is not None:
123
+ cache_token_len = self.token_min_hop_len + self.token_overlap_len - this_tts_speech_token.shape[1]
124
+ this_tts_speech_token = torch.concat([torch.concat(cache_token[-cache_token_len:], dim=1), this_tts_speech_token], dim=1)
125
+ else:
126
+ cache_token_len = 0
127
+ with self.flow_hift_context:
128
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
129
+ prompt_token=flow_prompt_speech_token.to(self.device),
130
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
131
+ prompt_feat=prompt_speech_feat.to(self.device),
132
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
133
+ embedding=flow_embedding.to(self.device))
134
+ this_tts_speech = this_tts_speech[:, int(cache_token_len / this_tts_speech_token.shape[1] * this_tts_speech.shape[1]):]
135
+ if cache_speech is not None:
136
+ this_tts_speech = fade_in_out(this_tts_speech, cache_speech, self.window)
137
+ yield {'tts_speech': this_tts_speech}
138
  else:
139
+ # deal with all tokens
140
+ p.join()
141
+ this_tts_speech_token = torch.concat(self.tts_speech_token[this_uuid], dim=1)
142
+ with self.flow_hift_context:
143
+ this_tts_speech = self.token2wav(token=this_tts_speech_token,
144
+ prompt_token=flow_prompt_speech_token.to(self.device),
145
+ prompt_token_len=flow_prompt_speech_token_len.to(self.device),
146
+ prompt_feat=prompt_speech_feat.to(self.device),
147
+ prompt_feat_len=prompt_speech_feat_len.to(self.device),
148
+ embedding=flow_embedding.to(self.device))
149
+ yield {'tts_speech': this_tts_speech}
150
+ with self.lock:
151
+ self.tts_speech_token.pop(this_uuid)
152
+ self.llm_end.pop(this_uuid)
153
+ torch.cuda.synchronize()
 
 
 
 
 
 
 
 
 
 
 
cosyvoice/flow/flow.py CHANGED
@@ -105,6 +105,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
105
  embedding = self.spk_embed_affine_layer(embedding)
106
 
107
  # concat text and prompt_text
 
108
  token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
109
  mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
110
  token = self.input_embedding(torch.clamp(token, min=0)) * mask
@@ -112,17 +113,16 @@ class MaskedDiffWithXvec(torch.nn.Module):
112
  # text encode
113
  h, h_lengths = self.encoder(token, token_len)
114
  h = self.encoder_proj(h)
115
- feat_len = (token_len / 50 * 22050 / 256).int()
116
- h, h_lengths = self.length_regulator(h, feat_len)
117
 
118
  # get conditions
119
- conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
120
- if prompt_feat.shape[1] != 0:
121
- for i, j in enumerate(prompt_feat_len):
122
- conds[i, :j] = prompt_feat[i]
123
  conds = conds.transpose(1, 2)
124
 
125
- mask = (~make_pad_mask(feat_len)).to(h)
 
126
  feat = self.decoder(
127
  mu=h.transpose(1, 2).contiguous(),
128
  mask=mask.unsqueeze(1),
@@ -130,6 +130,6 @@ class MaskedDiffWithXvec(torch.nn.Module):
130
  cond=conds,
131
  n_timesteps=10
132
  )
133
- if prompt_feat.shape[1] != 0:
134
- feat = feat[:, :, prompt_feat.shape[1]:]
135
  return feat
 
105
  embedding = self.spk_embed_affine_layer(embedding)
106
 
107
  # concat text and prompt_text
108
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
109
  token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
110
  mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
111
  token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
113
  # text encode
114
  h, h_lengths = self.encoder(token, token_len)
115
  h = self.encoder_proj(h)
116
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
117
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
118
 
119
  # get conditions
120
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
121
+ conds[:, :mel_len1] = prompt_feat
 
 
122
  conds = conds.transpose(1, 2)
123
 
124
+ # mask = (~make_pad_mask(feat_len)).to(h)
125
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
126
  feat = self.decoder(
127
  mu=h.transpose(1, 2).contiguous(),
128
  mask=mask.unsqueeze(1),
 
130
  cond=conds,
131
  n_timesteps=10
132
  )
133
+ feat = feat[:, :, mel_len1:]
134
+ assert feat.shape[2] == mel_len2
135
  return feat
cosyvoice/flow/length_regulator.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
  from typing import Tuple
15
  import torch.nn as nn
 
16
  from torch.nn import functional as F
17
  from cosyvoice.utils.mask import make_pad_mask
18
 
@@ -47,3 +48,21 @@ class InterpolateRegulator(nn.Module):
47
  out = self.model(x).transpose(1, 2).contiguous()
48
  olens = ylens
49
  return out * mask, olens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  from typing import Tuple
15
  import torch.nn as nn
16
+ import torch
17
  from torch.nn import functional as F
18
  from cosyvoice.utils.mask import make_pad_mask
19
 
 
48
  out = self.model(x).transpose(1, 2).contiguous()
49
  olens = ylens
50
  return out * mask, olens
51
+
52
+ def inference(self, x1, x2, mel_len1, mel_len2):
53
+ # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
54
+ # x in (B, T, D)
55
+ if x2.shape[1] > 40:
56
+ x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
57
+ x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
58
+ x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
59
+ x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
60
+ else:
61
+ x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
62
+ if x1.shape[1] != 0:
63
+ x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
64
+ x = torch.concat([x1, x2], dim=2)
65
+ else:
66
+ x = x2
67
+ out = self.model(x).transpose(1, 2).contiguous()
68
+ return out, mel_len1 + mel_len2
cosyvoice/llm/llm.py CHANGED
@@ -11,7 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- from typing import Dict, Optional, Union
15
  import torch
16
  from torch import nn
17
  import torch.nn.functional as F
@@ -31,6 +31,7 @@ class TransformerLM(torch.nn.Module):
31
  speech_token_size: int,
32
  text_encoder: torch.nn.Module,
33
  llm: torch.nn.Module,
 
34
  length_normalized_loss: bool = True,
35
  lsm_weight: float = 0.0,
36
  spk_embed_dim: int = 192,
@@ -63,6 +64,9 @@ class TransformerLM(torch.nn.Module):
63
  self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
64
  self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
65
 
 
 
 
66
  def encode(
67
  self,
68
  text: torch.Tensor,
@@ -132,14 +136,12 @@ class TransformerLM(torch.nn.Module):
132
  def sampling_ids(
133
  self,
134
  weighted_scores: torch.Tensor,
135
- sampling: Union[bool, int, float] = True,
136
- beam_size: int = 1,
137
  ignore_eos: bool = True,
138
  ):
139
  while True:
140
- prob, indices = weighted_scores.softmax(dim=-1).topk(sampling)
141
- top_ids = prob.multinomial(beam_size, replacement=True)
142
- top_ids = indices[top_ids]
143
  if (not ignore_eos) or (self.speech_token_size not in top_ids):
144
  break
145
  return top_ids
@@ -154,12 +156,10 @@ class TransformerLM(torch.nn.Module):
154
  prompt_speech_token: torch.Tensor,
155
  prompt_speech_token_len: torch.Tensor,
156
  embedding: torch.Tensor,
157
- beam_size: int = 1,
158
  sampling: int = 25,
159
  max_token_text_ratio: float = 20,
160
  min_token_text_ratio: float = 2,
161
- stream: bool = False,
162
- ) -> torch.Tensor:
163
  device = text.device
164
  text = torch.concat([prompt_text, text], dim=1)
165
  text_len += prompt_text_len
@@ -197,16 +197,11 @@ class TransformerLM(torch.nn.Module):
197
  y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
198
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
199
  logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
200
- top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
201
  if top_ids == self.speech_token_size:
202
  break
203
  # in stream mode, yield token one by one
204
- if stream is True:
205
- yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
206
  out_tokens.append(top_ids)
207
  offset += lm_input.size(1)
208
  lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
209
-
210
- # in non-stream mode, yield all token
211
- if stream is False:
212
- yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ from typing import Dict, Optional, Callable, List, Generator
15
  import torch
16
  from torch import nn
17
  import torch.nn.functional as F
 
31
  speech_token_size: int,
32
  text_encoder: torch.nn.Module,
33
  llm: torch.nn.Module,
34
+ sampling: Callable,
35
  length_normalized_loss: bool = True,
36
  lsm_weight: float = 0.0,
37
  spk_embed_dim: int = 192,
 
64
  self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
65
  self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
66
 
67
+ # 4. sampling method
68
+ self.sampling = sampling
69
+
70
  def encode(
71
  self,
72
  text: torch.Tensor,
 
136
  def sampling_ids(
137
  self,
138
  weighted_scores: torch.Tensor,
139
+ decoded_tokens: List,
140
+ sampling: int,
141
  ignore_eos: bool = True,
142
  ):
143
  while True:
144
+ top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
 
 
145
  if (not ignore_eos) or (self.speech_token_size not in top_ids):
146
  break
147
  return top_ids
 
156
  prompt_speech_token: torch.Tensor,
157
  prompt_speech_token_len: torch.Tensor,
158
  embedding: torch.Tensor,
 
159
  sampling: int = 25,
160
  max_token_text_ratio: float = 20,
161
  min_token_text_ratio: float = 2,
162
+ ) -> Generator[torch.Tensor, None, None]:
 
163
  device = text.device
164
  text = torch.concat([prompt_text, text], dim=1)
165
  text_len += prompt_text_len
 
197
  y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
198
  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
199
  logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
200
+ top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
201
  if top_ids == self.speech_token_size:
202
  break
203
  # in stream mode, yield token one by one
204
+ yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
 
205
  out_tokens.append(top_ids)
206
  offset += lm_input.size(1)
207
  lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
 
 
 
 
cosyvoice/utils/common.py CHANGED
@@ -101,3 +101,37 @@ def init_weights(m, mean=0.0, std=0.01):
101
  classname = m.__class__.__name__
102
  if classname.find("Conv") != -1:
103
  m.weight.data.normal_(mean, std)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  classname = m.__class__.__name__
102
  if classname.find("Conv") != -1:
103
  m.weight.data.normal_(mean, std)
104
+
105
+ # Repetition Aware Sampling in VALL-E 2
106
+ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
107
+ top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
108
+ rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
109
+ if rep_num >= win_size * tau_r:
110
+ top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
111
+ return top_ids
112
+
113
+ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
114
+ prob, indices = [], []
115
+ cum_prob = 0.0
116
+ sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
117
+ for i in range(len(sorted_idx)):
118
+ # sampling both top-p and numbers.
119
+ if cum_prob < top_p and len(prob) < top_k:
120
+ cum_prob += sorted_value[i]
121
+ prob.append(sorted_value[i])
122
+ indices.append(sorted_idx[i])
123
+ else:
124
+ break
125
+ prob = torch.tensor(prob).to(weighted_scores)
126
+ indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
127
+ top_ids = indices[prob.multinomial(1, replacement=True)]
128
+ return top_ids
129
+
130
+ def random_sampling(weighted_scores, decoded_tokens, sampling):
131
+ top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
132
+ return top_ids
133
+
134
+ def fade_in_out(fade_in_speech, fade_out_speech, window):
135
+ speech_overlap_len = int(window.shape[0] / 2)
136
+ fade_in_speech[:, :speech_overlap_len] = fade_in_speech[:, :speech_overlap_len] * window[:speech_overlap_len] + fade_out_speech[:, -speech_overlap_len:] * window[speech_overlap_len:]
137
+ return fade_in_speech
examples/libritts/cosyvoice/conf/cosyvoice.fromscratch.yaml CHANGED
@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
 
 
 
 
 
57
 
58
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
  input_size: 512
 
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
57
+ sampling: !name:cosyvoice.utils.common.ras_sampling
58
+ top_p: 0.8
59
+ top_k: 25
60
+ win_size: 10
61
+ tau_r: 0.1
62
 
63
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
64
  input_size: 512
examples/libritts/cosyvoice/conf/cosyvoice.yaml CHANGED
@@ -54,6 +54,11 @@ llm: !new:cosyvoice.llm.llm.TransformerLM
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
 
 
 
 
 
57
 
58
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
59
  input_size: 512
 
54
  pos_enc_layer_type: 'rel_pos_espnet'
55
  selfattention_layer_type: 'rel_selfattn'
56
  static_chunk_size: 1
57
+ sampling: !name:cosyvoice.utils.common.ras_sampling
58
+ top_p: 0.8
59
+ top_k: 25
60
+ win_size: 10
61
+ tau_r: 0.1
62
 
63
  flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
64
  input_size: 512
runtime/python/grpc/client.py CHANGED
@@ -61,8 +61,11 @@ def main():
61
  request.instruct_request.CopyFrom(instruct_request)
62
 
63
  response = stub.Inference(request)
 
 
 
 
64
  logging.info('save response to {}'.format(args.tts_wav))
65
- tts_speech = torch.from_numpy(np.array(np.frombuffer(response.tts_audio, dtype=np.int16))).unsqueeze(dim=0)
66
  torchaudio.save(args.tts_wav, tts_speech, target_sr)
67
  logging.info('get response')
68
 
 
61
  request.instruct_request.CopyFrom(instruct_request)
62
 
63
  response = stub.Inference(request)
64
+ tts_audio = b''
65
+ for r in response:
66
+ tts_audio += r.tts_audio
67
+ tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
68
  logging.info('save response to {}'.format(args.tts_wav))
 
69
  torchaudio.save(args.tts_wav, tts_speech, target_sr)
70
  logging.info('get response')
71
 
runtime/python/grpc/cosyvoice.proto CHANGED
@@ -4,7 +4,7 @@ package cosyvoice;
4
  option go_package = "protos/";
5
 
6
  service CosyVoice{
7
- rpc Inference(Request) returns (Response) {}
8
  }
9
 
10
  message Request{
 
4
  option go_package = "protos/";
5
 
6
  service CosyVoice{
7
+ rpc Inference(Request) returns (stream Response) {}
8
  }
9
 
10
  message Request{
runtime/python/grpc/server.py CHANGED
@@ -54,9 +54,10 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
54
  model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
55
 
56
  logging.info('send inference response')
57
- response = cosyvoice_pb2.Response()
58
- response.tts_audio = (model_output['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
59
- return response
 
60
 
61
  def main():
62
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
 
54
  model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
55
 
56
  logging.info('send inference response')
57
+ for i in model_output:
58
+ response = cosyvoice_pb2.Response()
59
+ response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
60
+ yield response
61
 
62
  def main():
63
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
webui.py CHANGED
@@ -164,7 +164,7 @@ def main():
164
  outputs=[audio_output])
165
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
166
  demo.queue(max_size=4, default_concurrency_limit=2)
167
- demo.launch(server_port=args.port)
168
 
169
  if __name__ == '__main__':
170
  parser = argparse.ArgumentParser()
 
164
  outputs=[audio_output])
165
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
166
  demo.queue(max_size=4, default_concurrency_limit=2)
167
+ demo.launch(server_name='0.0.0.0', server_port=args.port)
168
 
169
  if __name__ == '__main__':
170
  parser = argparse.ArgumentParser()