CosyVoice committed
Commit 7f5e391 (2 parents: 7795445, e141634)

Merge pull request #353 from FunAudioLLM/inference_streaming

README.md CHANGED
@@ -167,7 +167,7 @@ docker build -t cosyvoice:v1.0 .
 docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
 cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 # for fastapi usage
-docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && MODEL_DIR=iic/CosyVoice-300M fastapi dev --port 50000 server.py && sleep infinity"
+docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
 cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
 ```
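For reference, the rewritten FastAPI server streams raw 16-bit PCM rather than a finished WAV file, so it is normally driven through runtime/python/fastapi/client.py (updated below). A minimal request sketch along the same lines, assuming the container above is reachable on localhost:50000 and a 22050 Hz output rate (the spk_id is a placeholder; pick one from list_avaliable_spks()):

```python
import numpy as np
import requests
import torch
import torchaudio

# stream /inference_sft and collect the raw int16 PCM chunks
payload = {'tts_text': 'Hello, this is a streaming test.', 'spk_id': '中文女'}
response = requests.get('http://localhost:50000/inference_sft', data=payload, stream=True)

tts_audio = b''
for chunk in response.iter_content(chunk_size=16000):
    tts_audio += chunk

# wrap the PCM bytes into a (1, n_samples) tensor and write a wav file
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
torchaudio.save('sft_demo.wav', tts_speech, 22050)
```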
cosyvoice/bin/export_jit.py CHANGED
@@ -44,7 +44,7 @@ def main():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_set_profiling_executor(False)
 
-    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_trt=False)
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
 
     # 1. export llm text_encoder
     llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
@@ -60,5 +60,12 @@ def main():
     script = torch.jit.optimize_for_inference(script)
     script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
 
+    # 3. export flow encoder
+    flow_encoder = cosyvoice.model.flow.encoder
+    script = torch.jit.script(flow_encoder)
+    script = torch.jit.freeze(script)
+    script = torch.jit.optimize_for_inference(script)
+    script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
+
 if __name__ == '__main__':
     main()
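As a quick sanity check, the newly exported flow encoder can be loaded back as a TorchScript module; a sketch, assuming the default --model_dir used above:

```python
import torch

# load the TorchScript artifact written by export_jit.py (path is an assumption)
flow_encoder = torch.jit.load('pretrained_models/CosyVoice-300M/flow.encoder.fp32.zip')
print(type(flow_encoder))   # a ScriptModule
print(flow_encoder.code)    # scripted forward() source, frozen for inference
```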
cosyvoice/bin/export_onnx.py ADDED
@@ -0,0 +1,109 @@
+# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+import os
+import sys
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+import onnxruntime
+import random
+import torch
+from tqdm import tqdm
+from cosyvoice.cli.cosyvoice import CosyVoice
+
+
+def get_dummy_input(batch_size, seq_len, out_channels, device):
+    x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
+    mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    t = torch.rand((batch_size), dtype=torch.float32, device=device)
+    spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
+    cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
+    return x, mask, mu, t, spks, cond
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description='export your model for deployment')
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='pretrained_models/CosyVoice-300M',
+                        help='local path')
+    args = parser.parse_args()
+    print(args)
+    return args
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
+
+    # 1. export flow decoder estimator
+    estimator = cosyvoice.model.flow.decoder.estimator
+
+    device = cosyvoice.model.device
+    batch_size, seq_len = 1, 256
+    out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
+    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
+    torch.onnx.export(
+        estimator,
+        (x, mask, mu, t, spks, cond),
+        '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
+        output_names=['estimator_out'],
+        dynamic_axes={
+            'x': {0: 'batch_size', 2: 'seq_len'},
+            'mask': {0: 'batch_size', 2: 'seq_len'},
+            'mu': {0: 'batch_size', 2: 'seq_len'},
+            'cond': {0: 'batch_size', 2: 'seq_len'},
+            't': {0: 'batch_size'},
+            'spks': {0: 'batch_size'},
+            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
+        }
+    )
+
+    # 2. test computation consistency
+    option = onnxruntime.SessionOptions()
+    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+    option.intra_op_num_threads = 1
+    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers)
+
+    for _ in tqdm(range(10)):
+        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
+        output_pytorch = estimator(x, mask, mu, t, spks, cond)
+        ort_inputs = {
+            'x': x.cpu().numpy(),
+            'mask': mask.cpu().numpy(),
+            'mu': mu.cpu().numpy(),
+            't': t.cpu().numpy(),
+            'spks': spks.cpu().numpy(),
+            'cond': cond.cpu().numpy()
+        }
+        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
+        torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
+
+if __name__ == "__main__":
+    main()
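The script already cross-checks the ONNX Runtime output against PyTorch; to inspect the exported graph on its own, a small sketch using the onnx package (added to requirements.txt in this PR; path assumes the default --model_dir):

```python
import onnx

# load and structurally validate the exported flow decoder estimator
model = onnx.load('pretrained_models/CosyVoice-300M/flow.decoder.estimator.fp32.onnx')
onnx.checker.check_model(model)
print([i.name for i in model.graph.input])    # ['x', 'mask', 'mu', 't', 'spks', 'cond']
print([o.name for o in model.graph.output])   # ['estimator_out']
```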
cosyvoice/bin/export_trt.py DELETED
@@ -1,8 +0,0 @@
-# TODO Same logic as export_jit: finish the ONNX export of the flow estimator.
-# Document the TensorRT installation steps here; if TensorRT is not installed, do not run this script, just tell the user to install it first.
-try:
-    import tensorrt
-except ImportError:
-    print('step 1: download\n step 2: unpack and install the whl')
-# In the install command, import the TensorRT root directory from an environment variable, e.g. os.environ['tensorrt_root_dir']/bin/exetrace, then run the export command from Python via subprocess.
-# I will add the run command to run.sh later: tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
cosyvoice/cli/cosyvoice.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 import time
+from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
@@ -21,7 +22,7 @@ from cosyvoice.utils.file_utils import logging
 
 class CosyVoice:
 
-    def __init__(self, model_dir, load_jit=True):
+    def __init__(self, model_dir, load_jit=True, load_onnx=True):
         instruct = True if '-Instruct' in model_dir else False
         self.model_dir = model_dir
         if not os.path.exists(model_dir):
@@ -41,7 +42,10 @@ class CosyVoice:
                         '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-                                '{}/llm.llm.fp16.zip'.format(model_dir))
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
+        if load_onnx:
+            self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs
 
     def list_avaliable_spks(self):
@@ -49,7 +53,7 @@ class CosyVoice:
         return spks
 
     def inference_sft(self, tts_text, spk_id, stream=False):
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_sft(i, spk_id)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
@@ -61,7 +65,7 @@ class CosyVoice:
 
     def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
@@ -74,7 +78,7 @@ class CosyVoice:
     def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
@@ -88,7 +92,7 @@ class CosyVoice:
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        for i in self.frontend.text_normalize(tts_text, split=True):
+        for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
             start_time = time.time()
             logging.info('synthesis text {}'.format(i))
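Putting the new flags together, typical usage after running the export scripts looks roughly like the sketch below. It assumes llm.text_encoder.fp16.zip, llm.llm.fp16.zip, flow.encoder.fp32.zip and flow.decoder.estimator.fp32.onnx already sit in the model directory and that the model outputs 22050 Hz audio; otherwise pass load_jit=False, load_onnx=False:

```python
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

# load_jit swaps in the TorchScript llm/flow encoders, load_onnx swaps the flow
# decoder estimator for the exported ONNX Runtime session
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=True, load_onnx=True)

spk_id = cosyvoice.list_avaliable_spks()[0]
for idx, chunk in enumerate(cosyvoice.inference_sft('Hello from the streaming branch.', spk_id, stream=True)):
    # each chunk is a dict holding a 'tts_speech' tensor, yielded as soon as it is ready
    torchaudio.save('sft_chunk_{}.wav'.format(idx), chunk['tts_speech'], 22050)
```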
cosyvoice/cli/model.py CHANGED
@@ -18,7 +18,7 @@ import time
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-
+import numpy as np
 
 class CosyVoiceModel:
 
@@ -60,11 +60,22 @@ class CosyVoiceModel:
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
         self.hift.to(self.device).eval()
 
-    def load_jit(self, llm_text_encoder_model, llm_llm_model):
+    def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
         llm_text_encoder = torch.jit.load(llm_text_encoder_model)
         self.llm.text_encoder = llm_text_encoder
         llm_llm = torch.jit.load(llm_llm_model)
         self.llm.llm = llm_llm
+        flow_encoder = torch.jit.load(flow_encoder_model)
+        self.flow.encoder = flow_encoder
+
+    def load_onnx(self, flow_decoder_estimator_model):
+        import onnxruntime
+        option = onnxruntime.SessionOptions()
+        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+        option.intra_op_num_threads = 1
+        providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
+        del self.flow.decoder.estimator
+        self.flow.decoder.estimator = onnxruntime.InferenceSession(flow_decoder_estimator_model, sess_options=option, providers=providers)
 
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
@@ -169,4 +180,5 @@ class CosyVoiceModel:
             self.llm_end_dict.pop(this_uuid)
             self.mel_overlap_dict.pop(this_uuid)
             self.hift_cache_dict.pop(this_uuid)
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
cosyvoice/flow/decoder.py CHANGED
@@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module):
             _type_: _description_
         """
 
-        t = self.time_embeddings(t)
+        t = self.time_embeddings(t).to(t.dtype)
         t = self.time_mlp(t)
 
         x = pack([x, mu], "b * t")[0]
cosyvoice/flow/flow.py CHANGED
@@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         # concat text and prompt_text
         token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
         token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
-        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
+        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
         token = self.input_embedding(torch.clamp(token, min=0)) * mask
 
         # text encode
cosyvoice/flow/flow_matching.py CHANGED
@@ -50,7 +50,7 @@ class ConditionalCFM(BASECFM):
             shape: (batch_size, n_feats, mel_timesteps)
         """
         z = torch.randn_like(mu) * temperature
-        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
         if self.t_scheduler == 'cosine':
             t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
         return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
@@ -71,16 +71,17 @@ class ConditionalCFM(BASECFM):
             cond: Not used but kept for future purposes
         """
         t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+        t = t.unsqueeze(dim=0)
 
         # I am storing this because I can later plot it by putting a debugger here and saving it to a file
         # Or in future might add like a return_all_steps flag
         sol = []
 
         for step in range(1, len(t_span)):
-            dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
+            dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
             # Classifier-Free Guidance inference introduced in VoiceBox
             if self.inference_cfg_rate > 0:
-                cfg_dphi_dt = self.estimator(
+                cfg_dphi_dt = self.forward_estimator(
                     x, mask,
                     torch.zeros_like(mu), t,
                     torch.zeros_like(spks) if spks is not None else None,
@@ -96,6 +97,21 @@ class ConditionalCFM(BASECFM):
 
         return sol[-1]
 
+    def forward_estimator(self, x, mask, mu, t, spks, cond):
+        if isinstance(self.estimator, torch.nn.Module):
+            return self.estimator.forward(x, mask, mu, t, spks, cond)
+        else:
+            ort_inputs = {
+                'x': x.cpu().numpy(),
+                'mask': mask.cpu().numpy(),
+                'mu': mu.cpu().numpy(),
+                't': t.cpu().numpy(),
+                'spks': spks.cpu().numpy(),
+                'cond': cond.cpu().numpy()
+            }
+            output = self.estimator.run(None, ort_inputs)[0]
+            return torch.tensor(output, dtype=x.dtype, device=x.device)
+
     def compute_loss(self, x1, mask, mu, spks=None, cond=None):
         """Computes diffusion loss
 
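One detail worth noting: the exported ONNX graph declares `t` as a 1-D input with a batch_size axis (see the dynamic_axes in export_onnx.py above), which is why the scalar time step is now unsqueezed before every forward_estimator call. A toy shape illustration, not CosyVoice code:

```python
import torch

t_span = torch.linspace(0, 1, 10 + 1)  # n_timesteps + 1 points on [0, 1]
t = t_span[0]                          # 0-dim scalar tensor
t = t.unsqueeze(dim=0)                 # shape (1,), matching the 't': [batch_size] ONNX input
print(t.shape)                         # torch.Size([1])
```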
cosyvoice/hifigan/generator.py CHANGED
@@ -340,7 +340,7 @@ class HiFTGenerator(nn.Module):
         s = self._f02source(f0)
 
         # use cache_source to avoid glitch
-        if cache_source.shape[2] == 0:
+        if cache_source.shape[2] != 0:
             s[:, :, :cache_source.shape[2]] = cache_source
 
         s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
examples/libritts/cosyvoice/run.sh CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
       --deepspeed_config ./conf/ds_stage2.json \
       --deepspeed.save_states model+optimizer
   done
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
 fi
examples/magicdata-read/cosyvoice/run.sh CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
      --deepspeed_config ./conf/ds_stage2.json \
      --deepspeed.save_states model+optimizer
   done
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
+  python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
+  python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
 fi
requirements.txt CHANGED
@@ -15,6 +15,7 @@ matplotlib==3.7.5
 modelscope==1.15.0
 networkx==3.1
 omegaconf==2.3.0
+onnx==1.16.0
 onnxruntime-gpu==1.16.0; sys_platform == 'linux'
 onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
 openai-whisper==20231117
@@ -25,6 +26,7 @@ soundfile==0.12.1
 tensorboard==2.14.0
 torch==2.0.1
 torchaudio==2.0.2
+uvicorn==0.30.0
 wget==3.2
 fastapi==0.111.0
 fastapi-cli==0.0.4
runtime/python/fastapi/client.py CHANGED
@@ -1,56 +1,68 @@
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 import logging
 import requests
+import torch
+import torchaudio
+import numpy as np
 
-def saveResponse(path, response):
-    # open the file in binary write mode
-    with open(path, 'wb') as file:
-        # write the binary content of the response to the file
-        file.write(response.content)
 
 def main():
-    api = args.api_base
+    url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode)
     if args.mode == 'sft':
-        url = api + "/api/inference/sft"
-        payload={
-            'tts': args.tts_text,
-            'role': args.spk_id
+        payload = {
+            'tts_text': args.tts_text,
+            'spk_id': args.spk_id
         }
-        response = requests.request("POST", url, data=payload)
-        saveResponse(args.tts_wav, response)
+        response = requests.request("GET", url, data=payload, stream=True)
     elif args.mode == 'zero_shot':
-        url = api + "/api/inference/zero-shot"
-        payload={
-            'tts': args.tts_text,
-            'prompt': args.prompt_text
+        payload = {
+            'tts_text': args.tts_text,
+            'prompt_text': args.prompt_text
         }
-        files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
-        response = requests.request("POST", url, data=payload, files=files)
-        saveResponse(args.tts_wav, response)
+        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
+        response = requests.request("GET", url, data=payload, files=files, stream=True)
     elif args.mode == 'cross_lingual':
-        url = api + "/api/inference/cross-lingual"
-        payload={
-            'tts': args.tts_text,
+        payload = {
+            'tts_text': args.tts_text,
         }
-        files=[('audio', ('prompt_audio.wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
-        response = requests.request("POST", url, data=payload, files=files)
-        saveResponse(args.tts_wav, response)
+        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+        response = requests.request("GET", url, data=payload, files=files, stream=True)
     else:
-        url = api + "/api/inference/instruct"
         payload = {
-            'tts': args.tts_text,
-            'role': args.spk_id,
-            'instruct': args.instruct_text
+            'tts_text': args.tts_text,
+            'spk_id': args.spk_id,
+            'instruct_text': args.instruct_text
        }
-        response = requests.request("POST", url, data=payload)
-        saveResponse(args.tts_wav, response)
-    logging.info("Response save to {}", args.tts_wav)
+        response = requests.request("GET", url, data=payload, stream=True)
+    tts_audio = b''
+    for r in response.iter_content(chunk_size=16000):
+        tts_audio += r
+    tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
+    logging.info('save response to {}'.format(args.tts_wav))
+    torchaudio.save(args.tts_wav, tts_speech, target_sr)
+    logging.info('get response')
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--api_base',
+    parser.add_argument('--host',
                         type=str,
-                        default='http://127.0.0.1:6006')
+                        default='0.0.0.0')
+    parser.add_argument('--port',
+                        type=int,
+                        default='50000')
     parser.add_argument('--mode',
                         default='sft',
                         choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
runtime/python/fastapi/server.py CHANGED
@@ -1,119 +1,77 @@
-# Set inference model
-# export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
-# For development
-# fastapi dev --port 6006 fastapi_server.py
-# For production deployment
-# fastapi run --port 6006 fastapi_server.py
-
+# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import os
 import sys
-import io,time
-from fastapi import FastAPI, Response, File, UploadFile, Form
-from fastapi.responses import HTMLResponse
-from fastapi.middleware.cors import CORSMiddleware  # import the CORS middleware module
-from contextlib import asynccontextmanager
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../../..'.format(ROOT_DIR))
 sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
-from cosyvoice.cli.cosyvoice import CosyVoice
-from cosyvoice.utils.file_utils import load_wav
-import numpy as np
-import torch
-import torchaudio
+import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
+from fastapi import FastAPI, UploadFile, Form, File
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+import uvicorn
+import numpy as np
+from cosyvoice.cli.cosyvoice import CosyVoice
+from cosyvoice.utils.file_utils import load_wav
 
-class LaunchFailed(Exception):
-    pass
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
-    if model_dir:
-        logging.info("MODEL_DIR is {}", model_dir)
-        app.cosyvoice = CosyVoice(model_dir)
-        # sft usage
-        logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks())
-    else:
-        raise LaunchFailed("MODEL_DIR environment must set")
-    yield
-
-app = FastAPI(lifespan=lifespan)
-
-# set allowed origins
-origins = ["*"]  # "*" allows all origins; restrict to specific IPs if needed
+app = FastAPI()
+# set cross region allowance
 app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,  # allowed request origins
+    CORSMiddleware,
+    allow_origins=["*"],
     allow_credentials=True,
-    allow_methods=["*"],  # allowed HTTP methods, e.g. GET, POST, PUT
-    allow_headers=["*"])  # allowed request headers, e.g. for identifying the caller
-
-def buildResponse(output):
-    buffer = io.BytesIO()
-    torchaudio.save(buffer, output, 22050, format="wav")
-    buffer.seek(0)
-    return Response(content=buffer.read(-1), media_type="audio/wav")
-
-@app.post("/api/inference/sft")
-@app.get("/api/inference/sft")
-async def sft(tts: str = Form(), role: str = Form()):
-    start = time.process_time()
-    output = app.cosyvoice.inference_sft(tts, role)
-    end = time.process_time()
-    logging.info("infer time is {} seconds", end-start)
-    return buildResponse(output['tts_speech'])
-
-@app.post("/api/inference/zero-shot")
-async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
-    start = time.process_time()
-    prompt_speech = load_wav(audio.file, 16000)
-    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
-    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
-    prompt_speech_16k = prompt_speech_16k.float() / (2**15)
+    allow_methods=["*"],
+    allow_headers=["*"])
 
-    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
-    end = time.process_time()
-    logging.info("infer time is {} seconds", end-start)
-    return buildResponse(output['tts_speech'])
+def generate_data(model_output):
+    for i in model_output:
+        tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
+        yield tts_audio
 
-@app.post("/api/inference/cross-lingual")
-async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
-    start = time.process_time()
-    prompt_speech = load_wav(audio.file, 16000)
-    prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
-    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
-    prompt_speech_16k = prompt_speech_16k.float() / (2**15)
+@app.get("/inference_sft")
+async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
+    model_output = cosyvoice.inference_sft(tts_text, spk_id)
+    return StreamingResponse(generate_data(model_output))
 
-    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
-    end = time.process_time()
-    logging.info("infer time is {} seconds", end-start)
-    return buildResponse(output['tts_speech'])
+@app.get("/inference_zero_shot")
+async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
+    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
+    model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
+    return StreamingResponse(generate_data(model_output))
 
-@app.post("/api/inference/instruct")
-@app.get("/api/inference/instruct")
-async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
-    start = time.process_time()
-    output = app.cosyvoice.inference_instruct(tts, role, instruct)
-    end = time.process_time()
-    logging.info("infer time is {} seconds", end-start)
-    return buildResponse(output['tts_speech'])
+@app.get("/inference_cross_lingual")
+async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
+    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
+    model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
+    return StreamingResponse(generate_data(model_output))
 
-@app.get("/api/roles")
-async def roles():
-    return {"roles": app.cosyvoice.list_avaliable_spks()}
+@app.get("/inference_instruct")
+async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
+    model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
+    return StreamingResponse(generate_data(model_output))
 
-@app.get("/", response_class=HTMLResponse)
-async def root():
-    return """
-    <!DOCTYPE html>
-    <html lang=zh-cn>
-    <head>
-      <meta charset=utf-8>
-      <title>Api information</title>
-    </head>
-    <body>
-      Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
-    </body>
-    </html>
-    """
+if __name__=='__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--port',
+                        type=int,
+                        default=50000)
+    parser.add_argument('--model_dir',
+                        type=str,
+                        default='iic/CosyVoice-300M',
+                        help='local path or modelscope repo id')
+    args = parser.parse_args()
+    cosyvoice = CosyVoice(args.model_dir)
+    uvicorn.run(app, host="127.0.0.1", port=args.port)
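The endpoints read form fields (plus a file upload for the prompt-based modes) even though they are registered as GET routes, and each response is a bare stream of int16 PCM bytes. A sketch of calling /inference_zero_shot directly, assuming the server runs on localhost:50000 and zero_shot.wav is a 16 kHz prompt recording:

```python
import requests

payload = {'tts_text': 'Text to synthesize in the prompt speaker voice.',
           'prompt_text': 'Transcript of the prompt recording.'}
files = [('prompt_wav', ('prompt_wav', open('zero_shot.wav', 'rb'), 'application/octet-stream'))]
response = requests.get('http://localhost:50000/inference_zero_shot', data=payload, files=files, stream=True)
pcm = b''.join(response.iter_content(chunk_size=16000))  # raw int16 PCM, no WAV header
```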