Spaces:
Running
on
Zero
Running
on
Zero
Merge pull request #353 from FunAudioLLM/inference_streaming
Browse files- README.md +1 -1
- cosyvoice/bin/export_jit.py +8 -1
- cosyvoice/bin/export_onnx.py +109 -0
- cosyvoice/bin/export_trt.py +0 -8
- cosyvoice/cli/cosyvoice.py +10 -6
- cosyvoice/cli/model.py +15 -3
- cosyvoice/flow/decoder.py +1 -1
- cosyvoice/flow/flow.py +1 -1
- cosyvoice/flow/flow_matching.py +19 -3
- cosyvoice/hifigan/generator.py +1 -1
- examples/libritts/cosyvoice/run.sh +6 -0
- examples/magicdata-read/cosyvoice/run.sh +6 -0
- requirements.txt +2 -0
- runtime/python/fastapi/client.py +46 -34
- runtime/python/fastapi/server.py +61 -103
README.md
CHANGED
@@ -167,7 +167,7 @@ docker build -t cosyvoice:v1.0 .
|
|
167 |
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
168 |
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
169 |
# for fastapi usage
|
170 |
-
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi &&
|
171 |
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
172 |
```
|
173 |
|
|
|
167 |
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
168 |
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
169 |
# for fastapi usage
|
170 |
+
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
171 |
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
172 |
```
|
173 |
|
cosyvoice/bin/export_jit.py
CHANGED
@@ -44,7 +44,7 @@ def main():
|
|
44 |
torch._C._jit_set_profiling_mode(False)
|
45 |
torch._C._jit_set_profiling_executor(False)
|
46 |
|
47 |
-
cosyvoice = CosyVoice(args.model_dir, load_jit=False,
|
48 |
|
49 |
# 1. export llm text_encoder
|
50 |
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
@@ -60,5 +60,12 @@ def main():
|
|
60 |
script = torch.jit.optimize_for_inference(script)
|
61 |
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
if __name__ == '__main__':
|
64 |
main()
|
|
|
44 |
torch._C._jit_set_profiling_mode(False)
|
45 |
torch._C._jit_set_profiling_executor(False)
|
46 |
|
47 |
+
cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
|
48 |
|
49 |
# 1. export llm text_encoder
|
50 |
llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
|
|
|
60 |
script = torch.jit.optimize_for_inference(script)
|
61 |
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
62 |
|
63 |
+
# 3. export flow encoder
|
64 |
+
flow_encoder = cosyvoice.model.flow.encoder
|
65 |
+
script = torch.jit.script(flow_encoder)
|
66 |
+
script = torch.jit.freeze(script)
|
67 |
+
script = torch.jit.optimize_for_inference(script)
|
68 |
+
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
69 |
+
|
70 |
if __name__ == '__main__':
|
71 |
main()
|
cosyvoice/bin/export_onnx.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, [email protected])
|
2 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from __future__ import print_function
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
import logging
|
20 |
+
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
21 |
+
import os
|
22 |
+
import sys
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append('{}/../..'.format(ROOT_DIR))
|
25 |
+
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
26 |
+
import onnxruntime
|
27 |
+
import random
|
28 |
+
import torch
|
29 |
+
from tqdm import tqdm
|
30 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
31 |
+
|
32 |
+
|
33 |
+
def get_dummy_input(batch_size, seq_len, out_channels, device):
    """Create dummy inputs for the flow decoder estimator.

    Returns the tuple ``(x, mask, mu, t, spks, cond)`` of float32 tensors
    placed on ``device``:

    * ``x``, ``mu``, ``cond``: uniform random, ``(batch_size, out_channels, seq_len)``
    * ``mask``: all ones, ``(batch_size, 1, seq_len)``
    * ``t``: uniform random, ``(batch_size,)``
    * ``spks``: uniform random, ``(batch_size, out_channels)``
    """
    tensor_opts = {'dtype': torch.float32, 'device': device}

    def _uniform(*shape):
        return torch.rand(shape, **tensor_opts)

    # Random draws stay in the order x, mu, t, spks, cond so that a seeded
    # generator produces the same values for each tensor.
    x = _uniform(batch_size, out_channels, seq_len)
    mask = torch.ones((batch_size, 1, seq_len), **tensor_opts)
    mu = _uniform(batch_size, out_channels, seq_len)
    t = _uniform(batch_size)
    spks = _uniform(batch_size, out_channels)
    cond = _uniform(batch_size, out_channels, seq_len)
    return x, mask, mu, t, spks, cond
|
41 |
+
|
42 |
+
|
43 |
+
def get_args():
    """Parse the command-line options of the export script.

    The only option is ``--model_dir`` (string, default
    ``pretrained_models/CosyVoice-300M``). The parsed namespace is printed
    to stdout and then returned.
    """
    cli = argparse.ArgumentParser(description='export your model for deployment')
    cli.add_argument('--model_dir', type=str, default='pretrained_models/CosyVoice-300M', help='local path')
    parsed = cli.parse_args()
    print(parsed)
    return parsed
|
52 |
+
|
53 |
+
def main():
    """Export the flow decoder estimator to ONNX and verify it against PyTorch.

    Steps:
      1. Load CosyVoice from ``--model_dir`` with the JIT and ONNX replacements
         disabled, so the original PyTorch estimator is available.
      2. Export ``flow.decoder.estimator`` to
         ``<model_dir>/flow.decoder.estimator.fp32.onnx`` with dynamic batch and
         sequence-length axes.
      3. Run ten randomly shaped inputs through both the PyTorch module and an
         onnxruntime session and assert that the outputs agree.

    Raises:
        AssertionError: if an ONNX output differs from the PyTorch output
            beyond ``rtol=1e-2`` / ``atol=1e-4``.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)

    # 1. export flow decoder estimator
    estimator = cosyvoice.model.flow.decoder.estimator
    # single source of truth for the exported file, used by export and by the check below
    onnx_path = '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir)

    device = cosyvoice.model.device
    batch_size, seq_len = 1, 256
    out_channels = estimator.out_channels
    x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
    torch.onnx.export(
        estimator,
        (x, mask, mu, t, spks, cond),
        onnx_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
        output_names=['estimator_out'],
        dynamic_axes={
            'x': {0: 'batch_size', 2: 'seq_len'},
            'mask': {0: 'batch_size', 2: 'seq_len'},
            'mu': {0: 'batch_size', 2: 'seq_len'},
            'cond': {0: 'batch_size', 2: 'seq_len'},
            't': {0: 'batch_size'},
            'spks': {0: 'batch_size'},
            'estimator_out': {0: 'batch_size', 2: 'seq_len'},
        }
    )

    # 2. test computation consistency
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
    estimator_onnx = onnxruntime.InferenceSession(onnx_path, sess_options=option, providers=providers)

    for _ in tqdm(range(10)):
        x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
        # the reference forward pass is only compared, never back-propagated,
        # so skip building an autograd graph for it
        with torch.no_grad():
            output_pytorch = estimator(x, mask, mu, t, spks, cond)
        ort_inputs = {
            'x': x.cpu().numpy(),
            'mask': mask.cpu().numpy(),
            'mu': mu.cpu().numpy(),
            't': t.cpu().numpy(),
            'spks': spks.cpu().numpy(),
            'cond': cond.cpu().numpy()
        }
        output_onnx = estimator_onnx.run(None, ort_inputs)[0]
        # assert_close is the supported replacement for the deprecated assert_allclose
        torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
107 |
+
|
108 |
+
if __name__ == "__main__":
|
109 |
+
main()
|
cosyvoice/bin/export_trt.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
# TODO 跟export_jit一样的逻辑,完成flow部分的estimator的onnx导出。
|
2 |
-
# tensorrt的安装方式,再这里写一下步骤提示如下,如果没有安装,那么不要执行这个脚本,提示用户先安装,不给选择
|
3 |
-
try:
|
4 |
-
import tensorrt
|
5 |
-
except ImportError:
|
6 |
-
print('step1, 下载\n step2. 解压,安装whl,')
|
7 |
-
# 安装命令里tensosrt的根目录用环境变量导入,比如os.environ['tensorrt_root_dir']/bin/exetrace,然后python里subprocess里执行导出命令
|
8 |
-
# 后面我会在run.sh里写好执行命令 tensorrt_root_dir=xxxx python cosyvoice/bin/export_trt.py --model_dir xxx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cosyvoice/cli/cosyvoice.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# limitations under the License.
|
14 |
import os
|
15 |
import time
|
|
|
16 |
from hyperpyyaml import load_hyperpyyaml
|
17 |
from modelscope import snapshot_download
|
18 |
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
@@ -21,7 +22,7 @@ from cosyvoice.utils.file_utils import logging
|
|
21 |
|
22 |
class CosyVoice:
|
23 |
|
24 |
-
def __init__(self, model_dir, load_jit=True):
|
25 |
instruct = True if '-Instruct' in model_dir else False
|
26 |
self.model_dir = model_dir
|
27 |
if not os.path.exists(model_dir):
|
@@ -41,7 +42,10 @@ class CosyVoice:
|
|
41 |
'{}/hift.pt'.format(model_dir))
|
42 |
if load_jit:
|
43 |
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
44 |
-
'{}/llm.llm.fp16.zip'.format(model_dir)
|
|
|
|
|
|
|
45 |
del configs
|
46 |
|
47 |
def list_avaliable_spks(self):
|
@@ -49,7 +53,7 @@ class CosyVoice:
|
|
49 |
return spks
|
50 |
|
51 |
def inference_sft(self, tts_text, spk_id, stream=False):
|
52 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
53 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
54 |
start_time = time.time()
|
55 |
logging.info('synthesis text {}'.format(i))
|
@@ -61,7 +65,7 @@ class CosyVoice:
|
|
61 |
|
62 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
|
63 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
64 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
65 |
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
66 |
start_time = time.time()
|
67 |
logging.info('synthesis text {}'.format(i))
|
@@ -74,7 +78,7 @@ class CosyVoice:
|
|
74 |
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
|
75 |
if self.frontend.instruct is True:
|
76 |
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
77 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
78 |
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
79 |
start_time = time.time()
|
80 |
logging.info('synthesis text {}'.format(i))
|
@@ -88,7 +92,7 @@ class CosyVoice:
|
|
88 |
if self.frontend.instruct is False:
|
89 |
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
90 |
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
91 |
-
for i in self.frontend.text_normalize(tts_text, split=True):
|
92 |
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
93 |
start_time = time.time()
|
94 |
logging.info('synthesis text {}'.format(i))
|
|
|
13 |
# limitations under the License.
|
14 |
import os
|
15 |
import time
|
16 |
+
from tqdm import tqdm
|
17 |
from hyperpyyaml import load_hyperpyyaml
|
18 |
from modelscope import snapshot_download
|
19 |
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
|
|
22 |
|
23 |
class CosyVoice:
|
24 |
|
25 |
+
def __init__(self, model_dir, load_jit=True, load_onnx=True):
|
26 |
instruct = True if '-Instruct' in model_dir else False
|
27 |
self.model_dir = model_dir
|
28 |
if not os.path.exists(model_dir):
|
|
|
42 |
'{}/hift.pt'.format(model_dir))
|
43 |
if load_jit:
|
44 |
self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
|
45 |
+
'{}/llm.llm.fp16.zip'.format(model_dir),
|
46 |
+
'{}/flow.encoder.fp32.zip'.format(model_dir))
|
47 |
+
if load_onnx:
|
48 |
+
self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
|
49 |
del configs
|
50 |
|
51 |
def list_avaliable_spks(self):
|
|
|
53 |
return spks
|
54 |
|
55 |
def inference_sft(self, tts_text, spk_id, stream=False):
|
56 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
57 |
model_input = self.frontend.frontend_sft(i, spk_id)
|
58 |
start_time = time.time()
|
59 |
logging.info('synthesis text {}'.format(i))
|
|
|
65 |
|
66 |
def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
|
67 |
prompt_text = self.frontend.text_normalize(prompt_text, split=False)
|
68 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
69 |
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
|
70 |
start_time = time.time()
|
71 |
logging.info('synthesis text {}'.format(i))
|
|
|
78 |
def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
|
79 |
if self.frontend.instruct is True:
|
80 |
raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
|
81 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
82 |
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
|
83 |
start_time = time.time()
|
84 |
logging.info('synthesis text {}'.format(i))
|
|
|
92 |
if self.frontend.instruct is False:
|
93 |
raise ValueError('{} do not support instruct inference'.format(self.model_dir))
|
94 |
instruct_text = self.frontend.text_normalize(instruct_text, split=False)
|
95 |
+
for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
|
96 |
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
97 |
start_time = time.time()
|
98 |
logging.info('synthesis text {}'.format(i))
|
cosyvoice/cli/model.py
CHANGED
@@ -18,7 +18,7 @@ import time
|
|
18 |
from contextlib import nullcontext
|
19 |
import uuid
|
20 |
from cosyvoice.utils.common import fade_in_out
|
21 |
-
|
22 |
|
23 |
class CosyVoiceModel:
|
24 |
|
@@ -60,11 +60,22 @@ class CosyVoiceModel:
|
|
60 |
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
61 |
self.hift.to(self.device).eval()
|
62 |
|
63 |
-
def load_jit(self, llm_text_encoder_model, llm_llm_model):
|
64 |
llm_text_encoder = torch.jit.load(llm_text_encoder_model)
|
65 |
self.llm.text_encoder = llm_text_encoder
|
66 |
llm_llm = torch.jit.load(llm_llm_model)
|
67 |
self.llm.llm = llm_llm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
70 |
with self.llm_context:
|
@@ -169,4 +180,5 @@ class CosyVoiceModel:
|
|
169 |
self.llm_end_dict.pop(this_uuid)
|
170 |
self.mel_overlap_dict.pop(this_uuid)
|
171 |
self.hift_cache_dict.pop(this_uuid)
|
172 |
-
torch.cuda.
|
|
|
|
18 |
from contextlib import nullcontext
|
19 |
import uuid
|
20 |
from cosyvoice.utils.common import fade_in_out
|
21 |
+
import numpy as np
|
22 |
|
23 |
class CosyVoiceModel:
|
24 |
|
|
|
60 |
self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
|
61 |
self.hift.to(self.device).eval()
|
62 |
|
63 |
+
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
    """Replace three submodules with their TorchScript exports.

    Args:
        llm_text_encoder_model: path of the TorchScript archive assigned to
            ``self.llm.text_encoder``.
        llm_llm_model: path of the TorchScript archive assigned to
            ``self.llm.llm``.
        flow_encoder_model: path of the TorchScript archive assigned to
            ``self.flow.encoder``.

    Each archive is mapped onto ``self.device`` while loading, the same way
    the state dicts are loaded elsewhere in this class. This keeps loading
    independent of the device the archive was exported on.
    """
    self.llm.text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
    self.llm.llm = torch.jit.load(llm_llm_model, map_location=self.device)
    self.flow.encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
70 |
+
|
71 |
+
def load_onnx(self, flow_decoder_estimator_model):
    """Swap the flow decoder's estimator for an onnxruntime inference session.

    Args:
        flow_decoder_estimator_model: path of the exported ONNX estimator.

    The session uses full graph optimisation and a single intra-op thread.
    It asks for the CUDA execution provider when ``torch.cuda.is_available()``
    is true and for the CPU provider otherwise.
    """
    import onnxruntime

    if torch.cuda.is_available():
        execution_provider = 'CUDAExecutionProvider'
    else:
        execution_provider = 'CPUExecutionProvider'

    session_options = onnxruntime.SessionOptions()
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.intra_op_num_threads = 1

    # The existing estimator attribute is removed first and only then rebound
    # to the (non-Module) session object.
    del self.flow.decoder.estimator
    self.flow.decoder.estimator = onnxruntime.InferenceSession(
        flow_decoder_estimator_model, sess_options=session_options, providers=[execution_provider])
|
79 |
|
80 |
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
81 |
with self.llm_context:
|
|
|
180 |
self.llm_end_dict.pop(this_uuid)
|
181 |
self.mel_overlap_dict.pop(this_uuid)
|
182 |
self.hift_cache_dict.pop(this_uuid)
|
183 |
+
if torch.cuda.is_available():
|
184 |
+
torch.cuda.synchronize()
|
cosyvoice/flow/decoder.py
CHANGED
@@ -159,7 +159,7 @@ class ConditionalDecoder(nn.Module):
|
|
159 |
_type_: _description_
|
160 |
"""
|
161 |
|
162 |
-
t = self.time_embeddings(t)
|
163 |
t = self.time_mlp(t)
|
164 |
|
165 |
x = pack([x, mu], "b * t")[0]
|
|
|
159 |
_type_: _description_
|
160 |
"""
|
161 |
|
162 |
+
t = self.time_embeddings(t).to(t.dtype)
|
163 |
t = self.time_mlp(t)
|
164 |
|
165 |
x = pack([x, mu], "b * t")[0]
|
cosyvoice/flow/flow.py
CHANGED
@@ -113,7 +113,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
|
|
113 |
# concat text and prompt_text
|
114 |
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
115 |
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
116 |
-
mask = (~make_pad_mask(token_len)).
|
117 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
118 |
|
119 |
# text encode
|
|
|
113 |
# concat text and prompt_text
|
114 |
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
115 |
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
116 |
+
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
117 |
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
118 |
|
119 |
# text encode
|
cosyvoice/flow/flow_matching.py
CHANGED
@@ -50,7 +50,7 @@ class ConditionalCFM(BASECFM):
|
|
50 |
shape: (batch_size, n_feats, mel_timesteps)
|
51 |
"""
|
52 |
z = torch.randn_like(mu) * temperature
|
53 |
-
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
54 |
if self.t_scheduler == 'cosine':
|
55 |
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
56 |
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
@@ -71,16 +71,17 @@ class ConditionalCFM(BASECFM):
|
|
71 |
cond: Not used but kept for future purposes
|
72 |
"""
|
73 |
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
|
|
74 |
|
75 |
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
76 |
# Or in future might add like a return_all_steps flag
|
77 |
sol = []
|
78 |
|
79 |
for step in range(1, len(t_span)):
|
80 |
-
dphi_dt = self.
|
81 |
# Classifier-Free Guidance inference introduced in VoiceBox
|
82 |
if self.inference_cfg_rate > 0:
|
83 |
-
cfg_dphi_dt = self.
|
84 |
x, mask,
|
85 |
torch.zeros_like(mu), t,
|
86 |
torch.zeros_like(spks) if spks is not None else None,
|
@@ -96,6 +97,21 @@ class ConditionalCFM(BASECFM):
|
|
96 |
|
97 |
return sol[-1]
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
100 |
"""Computes diffusion loss
|
101 |
|
|
|
50 |
shape: (batch_size, n_feats, mel_timesteps)
|
51 |
"""
|
52 |
z = torch.randn_like(mu) * temperature
|
53 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
54 |
if self.t_scheduler == 'cosine':
|
55 |
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
56 |
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
|
|
|
71 |
cond: Not used but kept for future purposes
|
72 |
"""
|
73 |
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
74 |
+
t = t.unsqueeze(dim=0)
|
75 |
|
76 |
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
77 |
# Or in future might add like a return_all_steps flag
|
78 |
sol = []
|
79 |
|
80 |
for step in range(1, len(t_span)):
|
81 |
+
dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
|
82 |
# Classifier-Free Guidance inference introduced in VoiceBox
|
83 |
if self.inference_cfg_rate > 0:
|
84 |
+
cfg_dphi_dt = self.forward_estimator(
|
85 |
x, mask,
|
86 |
torch.zeros_like(mu), t,
|
87 |
torch.zeros_like(spks) if spks is not None else None,
|
|
|
97 |
|
98 |
return sol[-1]
|
99 |
|
100 |
+
def forward_estimator(self, x, mask, mu, t, spks, cond):
    """Evaluate the estimator, whether it is a PyTorch module or an ONNX session.

    When ``self.estimator`` is a ``torch.nn.Module`` its ``forward`` is called
    directly. Otherwise it is treated as an object with an onnxruntime-style
    ``run`` method: all six inputs are copied to CPU numpy arrays, fed by
    name, and the first output is returned as a tensor with the dtype and
    device of ``x``.

    NOTE(review): the non-Module path calls ``.cpu()`` on every argument, so
    it cannot accept ``spks=None`` or ``cond=None`` — confirm callers never
    pass ``None`` when the ONNX estimator is loaded.
    """
    if isinstance(self.estimator, torch.nn.Module):
        return self.estimator.forward(x, mask, mu, t, spks, cond)

    # Feed names must match the input names used when the graph was exported.
    named_inputs = (('x', x), ('mask', mask), ('mu', mu), ('t', t), ('spks', spks), ('cond', cond))
    feeds = {name: tensor.cpu().numpy() for name, tensor in named_inputs}
    first_output = self.estimator.run(None, feeds)[0]
    return torch.tensor(first_output, dtype=x.dtype, device=x.device)
|
114 |
+
|
115 |
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
116 |
"""Computes diffusion loss
|
117 |
|
cosyvoice/hifigan/generator.py
CHANGED
@@ -340,7 +340,7 @@ class HiFTGenerator(nn.Module):
|
|
340 |
s = self._f02source(f0)
|
341 |
|
342 |
# use cache_source to avoid glitch
|
343 |
-
if cache_source.shape[2]
|
344 |
s[:, :, :cache_source.shape[2]] = cache_source
|
345 |
|
346 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
|
|
340 |
s = self._f02source(f0)
|
341 |
|
342 |
# use cache_source to avoid glitch
|
343 |
+
if cache_source.shape[2] != 0:
|
344 |
s[:, :, :cache_source.shape[2]] = cache_source
|
345 |
|
346 |
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
examples/libritts/cosyvoice/run.sh
CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
fi
|
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
105 |
+
fi
|
106 |
+
|
107 |
+
# Stage 6: export the TorchScript and ONNX artifacts used for faster inference.
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
  # Quote the directory so a path containing spaces is passed as one argument.
  python cosyvoice/bin/export_jit.py --model_dir "$pretrained_model_dir"
  python cosyvoice/bin/export_onnx.py --model_dir "$pretrained_model_dir"
fi
|
examples/magicdata-read/cosyvoice/run.sh
CHANGED
@@ -102,4 +102,10 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
fi
|
|
|
102 |
--deepspeed_config ./conf/ds_stage2.json \
|
103 |
--deepspeed.save_states model+optimizer
|
104 |
done
|
105 |
+
fi
|
106 |
+
|
107 |
+
# Stage 6: export the TorchScript and ONNX artifacts used for faster inference.
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
  # Quote the directory so a path containing spaces is passed as one argument.
  python cosyvoice/bin/export_jit.py --model_dir "$pretrained_model_dir"
  python cosyvoice/bin/export_onnx.py --model_dir "$pretrained_model_dir"
fi
|
requirements.txt
CHANGED
@@ -15,6 +15,7 @@ matplotlib==3.7.5
|
|
15 |
modelscope==1.15.0
|
16 |
networkx==3.1
|
17 |
omegaconf==2.3.0
|
|
|
18 |
onnxruntime-gpu==1.16.0; sys_platform == 'linux'
|
19 |
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
|
20 |
openai-whisper==20231117
|
@@ -25,6 +26,7 @@ soundfile==0.12.1
|
|
25 |
tensorboard==2.14.0
|
26 |
torch==2.0.1
|
27 |
torchaudio==2.0.2
|
|
|
28 |
wget==3.2
|
29 |
fastapi==0.111.0
|
30 |
fastapi-cli==0.0.4
|
|
|
15 |
modelscope==1.15.0
|
16 |
networkx==3.1
|
17 |
omegaconf==2.3.0
|
18 |
+
onnx==1.16.0
|
19 |
onnxruntime-gpu==1.16.0; sys_platform == 'linux'
|
20 |
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows'
|
21 |
openai-whisper==20231117
|
|
|
26 |
tensorboard==2.14.0
|
27 |
torch==2.0.1
|
28 |
torchaudio==2.0.2
|
29 |
+
uvicorn==0.30.0
|
30 |
wget==3.2
|
31 |
fastapi==0.111.0
|
32 |
fastapi-cli==0.0.4
|
runtime/python/fastapi/client.py
CHANGED
@@ -1,56 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import argparse
|
2 |
import logging
|
3 |
import requests
|
|
|
|
|
|
|
4 |
|
5 |
-
def saveResponse(path, response):
|
6 |
-
# 以二进制写入模式打开文件
|
7 |
-
with open(path, 'wb') as file:
|
8 |
-
# 将响应的二进制内容写入文件
|
9 |
-
file.write(response.content)
|
10 |
|
11 |
def main():
|
12 |
-
|
13 |
if args.mode == 'sft':
|
14 |
-
|
15 |
-
|
16 |
-
'
|
17 |
-
'role': args.spk_id
|
18 |
}
|
19 |
-
response = requests.request("
|
20 |
-
saveResponse(args.tts_wav, response)
|
21 |
elif args.mode == 'zero_shot':
|
22 |
-
|
23 |
-
|
24 |
-
'
|
25 |
-
'prompt': args.prompt_text
|
26 |
}
|
27 |
-
files=[('
|
28 |
-
response = requests.request("
|
29 |
-
saveResponse(args.tts_wav, response)
|
30 |
elif args.mode == 'cross_lingual':
|
31 |
-
|
32 |
-
|
33 |
-
'tts': args.tts_text,
|
34 |
}
|
35 |
-
files=[('
|
36 |
-
response = requests.request("
|
37 |
-
saveResponse(args.tts_wav, response)
|
38 |
else:
|
39 |
-
url = api + "/api/inference/instruct"
|
40 |
payload = {
|
41 |
-
'
|
42 |
-
'
|
43 |
-
'
|
44 |
}
|
45 |
-
response = requests.request("
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
if __name__ == "__main__":
|
50 |
parser = argparse.ArgumentParser()
|
51 |
-
parser.add_argument('--
|
52 |
type=str,
|
53 |
-
default='
|
|
|
|
|
|
|
54 |
parser.add_argument('--mode',
|
55 |
default='sft',
|
56 |
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
import argparse
|
15 |
import logging
|
16 |
import requests
|
17 |
+
import torch
|
18 |
+
import torchaudio
|
19 |
+
import numpy as np
|
20 |
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def main():
    """Send one synthesis request to the FastAPI server and save the audio.

    Reads the module-level ``args`` namespace: ``host``, ``port``, ``mode``,
    ``tts_text``, ``tts_wav``, and the mode-specific ``spk_id``,
    ``prompt_text``, ``prompt_wav`` and ``instruct_text``. The request goes to
    ``http://<host>:<port>/inference_<mode>``; the streamed response body is
    interpreted as raw int16 samples and written to ``args.tts_wav``.

    NOTE(review): ``target_sr`` is not defined in the visible part of this
    file — presumably it is set at module level before ``main()`` runs;
    confirm.
    """
    url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode)
    if args.mode == 'sft':
        payload = {
            'tts_text': args.tts_text,
            'spk_id': args.spk_id
        }
        response = requests.request("GET", url, data=payload, stream=True)
    elif args.mode == 'zero_shot':
        payload = {
            'tts_text': args.tts_text,
            'prompt_text': args.prompt_text
        }
        # The prompt file only has to stay open while the request is sent;
        # the context manager closes it instead of leaking the handle.
        with open(args.prompt_wav, 'rb') as prompt_file:
            files = [('prompt_wav', ('prompt_wav', prompt_file, 'application/octet-stream'))]
            response = requests.request("GET", url, data=payload, files=files, stream=True)
    elif args.mode == 'cross_lingual':
        payload = {
            'tts_text': args.tts_text,
        }
        with open(args.prompt_wav, 'rb') as prompt_file:
            files = [('prompt_wav', ('prompt_wav', prompt_file, 'application/octet-stream'))]
            response = requests.request("GET", url, data=payload, files=files, stream=True)
    else:
        payload = {
            'tts_text': args.tts_text,
            'spk_id': args.spk_id,
            'instruct_text': args.instruct_text
        }
        response = requests.request("GET", url, data=payload, stream=True)
    # Join the streamed chunks once rather than growing a bytes object per chunk.
    tts_audio = b''.join(response.iter_content(chunk_size=16000))
    # np.array(...) copies the read-only frombuffer view before handing it to torch.
    tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
    logging.info('save response to {}'.format(args.tts_wav))
    torchaudio.save(args.tts_wav, tts_speech, target_sr)
    logging.info('get response')
|
57 |
|
58 |
if __name__ == "__main__":
|
59 |
parser = argparse.ArgumentParser()
|
60 |
+
parser.add_argument('--host',
|
61 |
type=str,
|
62 |
+
default='0.0.0.0')
|
63 |
+
parser.add_argument('--port',
|
64 |
+
type=int,
|
65 |
+
default='50000')
|
66 |
parser.add_argument('--mode',
|
67 |
default='sft',
|
68 |
choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
|
runtime/python/fastapi/server.py
CHANGED
@@ -1,119 +1,77 @@
|
|
1 |
-
#
|
2 |
-
#
|
3 |
-
#
|
4 |
-
#
|
5 |
-
#
|
6 |
-
#
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
import os
|
9 |
import sys
|
10 |
-
import io,time
|
11 |
-
from fastapi import FastAPI, Response, File, UploadFile, Form
|
12 |
-
from fastapi.responses import HTMLResponse
|
13 |
-
from fastapi.middleware.cors import CORSMiddleware #引入 CORS中间件模块
|
14 |
-
from contextlib import asynccontextmanager
|
15 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
16 |
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
17 |
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
18 |
-
|
19 |
-
from cosyvoice.utils.file_utils import load_wav
|
20 |
-
import numpy as np
|
21 |
-
import torch
|
22 |
-
import torchaudio
|
23 |
import logging
|
24 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
@asynccontextmanager
|
30 |
-
async def lifespan(app: FastAPI):
|
31 |
-
model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
|
32 |
-
if model_dir:
|
33 |
-
logging.info("MODEL_DIR is {}", model_dir)
|
34 |
-
app.cosyvoice = CosyVoice(model_dir)
|
35 |
-
# sft usage
|
36 |
-
logging.info("Avaliable speakers {}", app.cosyvoice.list_avaliable_spks())
|
37 |
-
else:
|
38 |
-
raise LaunchFailed("MODEL_DIR environment must set")
|
39 |
-
yield
|
40 |
-
|
41 |
-
app = FastAPI(lifespan=lifespan)
|
42 |
-
|
43 |
-
#设置允许访问的域名
|
44 |
-
origins = ["*"] #"*",即为所有,也可以改为允许的特定ip。
|
45 |
app.add_middleware(
|
46 |
-
CORSMiddleware,
|
47 |
-
allow_origins=
|
48 |
allow_credentials=True,
|
49 |
-
allow_methods=["*"],
|
50 |
-
allow_headers=["*"])
|
51 |
-
|
52 |
-
def buildResponse(output):
|
53 |
-
buffer = io.BytesIO()
|
54 |
-
torchaudio.save(buffer, output, 22050, format="wav")
|
55 |
-
buffer.seek(0)
|
56 |
-
return Response(content=buffer.read(-1), media_type="audio/wav")
|
57 |
-
|
58 |
-
@app.post("/api/inference/sft")
|
59 |
-
@app.get("/api/inference/sft")
|
60 |
-
async def sft(tts: str = Form(), role: str = Form()):
|
61 |
-
start = time.process_time()
|
62 |
-
output = app.cosyvoice.inference_sft(tts, role)
|
63 |
-
end = time.process_time()
|
64 |
-
logging.info("infer time is {} seconds", end-start)
|
65 |
-
return buildResponse(output['tts_speech'])
|
66 |
-
|
67 |
-
@app.post("/api/inference/zero-shot")
|
68 |
-
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
|
69 |
-
start = time.process_time()
|
70 |
-
prompt_speech = load_wav(audio.file, 16000)
|
71 |
-
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
72 |
-
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
73 |
-
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
@app.
|
81 |
-
async def
|
82 |
-
|
83 |
-
|
84 |
-
prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
|
85 |
-
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
|
86 |
-
prompt_speech_16k = prompt_speech_16k.float() / (2**15)
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
92 |
|
93 |
-
@app.
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
end = time.process_time()
|
99 |
-
logging.info("infer time is {} seconds", end-start)
|
100 |
-
return buildResponse(output['tts_speech'])
|
101 |
|
102 |
-
@app.get("/
|
103 |
-
async def
|
104 |
-
|
|
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
</html>
|
119 |
-
"""
|
|
|
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
import os
|
15 |
import sys
|
|
|
|
|
|
|
|
|
|
|
16 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
17 |
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
18 |
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
19 |
+
import argparse
|
|
|
|
|
|
|
|
|
20 |
import logging
|
21 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
22 |
+
from fastapi import FastAPI, UploadFile, Form, File
|
23 |
+
from fastapi.responses import StreamingResponse
|
24 |
+
from fastapi.middleware.cors import CORSMiddleware
|
25 |
+
import uvicorn
|
26 |
+
import numpy as np
|
27 |
+
from cosyvoice.cli.cosyvoice import CosyVoice
|
28 |
+
from cosyvoice.utils.file_utils import load_wav
|
29 |
|
30 |
+
app = FastAPI()
# Enable CORS so that browser clients served from other origins can call the
# API: every origin, HTTP method and header is accepted, and credentials are
# allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
+
def generate_data(model_output):
    """Yield each synthesized speech chunk as raw 16-bit PCM bytes.

    Args:
        model_output: Iterable of dicts whose ``'tts_speech'`` entry exposes
            ``.numpy()`` and holds float samples (presumably a tensor scaled
            to [-1.0, 1.0] -- confirm against the CosyVoice inference API).

    Yields:
        bytes: The chunk's samples scaled by 2**15 and serialized as int16.
    """
    int16_range = np.iinfo(np.int16)
    for chunk in model_output:
        scaled = chunk['tts_speech'].numpy() * (2 ** 15)
        # A sample of exactly 1.0 scales to 32768, one past the int16 maximum;
        # clip first so loud samples saturate instead of wrapping around.
        pcm = np.clip(scaled, int16_range.min, int16_range.max).astype(np.int16)
        yield pcm.tobytes()
|
43 |
|
44 |
+
# The parameters are read from a form-encoded request body. Many HTTP clients
# and proxies do not send a body with GET, so the route is also registered
# for POST; the GET route is kept for existing callers.
@app.get("/inference_sft")
@app.post("/inference_sft")
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
    """Stream speech synthesized for ``tts_text`` with speaker ``spk_id``.

    Args:
        tts_text: Text to synthesize (form field).
        spk_id: Speaker identifier passed to ``cosyvoice.inference_sft``
            (form field).

    Returns:
        StreamingResponse: Streams the byte chunks produced by
        ``generate_data`` from the model output.
    """
    model_output = cosyvoice.inference_sft(tts_text, spk_id)
    return StreamingResponse(generate_data(model_output))
|
|
|
|
|
|
|
48 |
|
49 |
+
# The request carries form fields and an uploaded file in its body. Many HTTP
# clients and proxies do not send a body with GET, so the route is also
# registered for POST; the GET route is kept for existing callers.
@app.get("/inference_zero_shot")
@app.post("/inference_zero_shot")
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
    """Stream zero-shot speech for ``tts_text`` using an uploaded prompt.

    Args:
        tts_text: Text to synthesize (form field).
        prompt_text: Text passed alongside the prompt audio to
            ``cosyvoice.inference_zero_shot`` (form field).
        prompt_wav: Uploaded prompt audio; loaded with ``load_wav`` at
            16000 Hz.

    Returns:
        StreamingResponse: Streams the byte chunks produced by
        ``generate_data`` from the model output.
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
    return StreamingResponse(generate_data(model_output))
|
54 |
|
55 |
+
# The request carries a form field and an uploaded file in its body. Many HTTP
# clients and proxies do not send a body with GET, so the route is also
# registered for POST; the GET route is kept for existing callers.
@app.get("/inference_cross_lingual")
@app.post("/inference_cross_lingual")
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
    """Stream cross-lingual speech for ``tts_text`` using an uploaded prompt.

    Args:
        tts_text: Text to synthesize (form field).
        prompt_wav: Uploaded prompt audio; loaded with ``load_wav`` at
            16000 Hz.

    Returns:
        StreamingResponse: Streams the byte chunks produced by
        ``generate_data`` from the model output.
    """
    prompt_speech_16k = load_wav(prompt_wav.file, 16000)
    model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
    return StreamingResponse(generate_data(model_output))
|
|
|
|
|
|
|
60 |
|
61 |
+
# The parameters are read from a form-encoded request body. Many HTTP clients
# and proxies do not send a body with GET, so the route is also registered
# for POST; the GET route is kept for existing callers.
@app.get("/inference_instruct")
@app.post("/inference_instruct")
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
    """Stream instructed speech for ``tts_text`` with speaker ``spk_id``.

    Args:
        tts_text: Text to synthesize (form field).
        spk_id: Speaker identifier (form field).
        instruct_text: Instruction text passed to
            ``cosyvoice.inference_instruct`` (form field).

    Returns:
        StreamingResponse: Streams the byte chunks produced by
        ``generate_data`` from the model output.
    """
    model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
    return StreamingResponse(generate_data(model_output))
|
65 |
|
66 |
+
if __name__ == '__main__':
    # Script entry point: parse the CLI options, load the model into the
    # module-level ``cosyvoice`` used by the route handlers, then serve ``app``.
    parser = argparse.ArgumentParser()
    # The documented docker usage publishes the port with ``-p 50000:50000``;
    # a server bound to 127.0.0.1 inside the container is not reachable through
    # that mapping, so listen on all interfaces by default. Pass
    # ``--host 127.0.0.1`` to restrict the server to local connections.
    parser.add_argument('--host',
                        type=str,
                        default='0.0.0.0',
                        help='interface address to bind the server to')
    parser.add_argument('--port',
                        type=int,
                        default=50000)
    parser.add_argument('--model_dir',
                        type=str,
                        default='iic/CosyVoice-300M',
                        help='local path or modelscope repo id')
    args = parser.parse_args()
    cosyvoice = CosyVoice(args.model_dir)
    uvicorn.run(app, host=args.host, port=args.port)
|
|
|
|