Merge pull request #56 from iflamed/fastapi
Add a FastAPI server to serve TTS, plus a client script that downloads the synthesized audio.
- README.md +1 -1
- requirements.txt +2 -0
- runtime/python/fastapi_client.py +78 -0
- runtime/python/fastapi_server.py +109 -0
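
The server exposes one endpoint per inference mode (`/api/inference/sft`, `/api/inference/zero-shot`, `/api/inference/cross-lingual`, `/api/inference/instruct`) plus `/api/roles` for listing the available speakers. A minimal sketch of the request flow, assuming the server is already running on the client's default base URL (`http://127.0.0.1:6006`):

```python
import requests

API_BASE = "http://127.0.0.1:6006"  # default used by fastapi_client.py

# Ask the server which speaker roles the loaded model supports
roles = requests.get(f"{API_BASE}/api/roles").json()["roles"]

# Synthesize with a pretrained (SFT) voice and save the WAV response
resp = requests.post(f"{API_BASE}/api/inference/sft",
                     data={"tts": "Text to synthesize.", "role": roles[0]})
with open("demo.wav", "wb") as f:
    f.write(resp.content)
```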
README.md
CHANGED
```diff
@@ -152,4 +152,4 @@ You can also scan the QR code to join our official Dingding chat group.
 5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
 
 ## Disclaimer
-The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
+The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
```
requirements.txt
CHANGED
```diff
@@ -26,4 +26,6 @@ tensorboard==2.14.0
 torch==2.0.1
 torchaudio==2.0.2
 wget==3.2
+fastapi==0.111.0
+fastapi-cli==0.0.4
 WeTextProcessing==1.0.3
```
runtime/python/fastapi_client.py
ADDED
```python
import argparse
import logging
import requests


def saveResponse(path, response):
    # Open the file in binary write mode
    with open(path, 'wb') as file:
        # Write the binary content of the response to the file
        file.write(response.content)


def main():
    api = args.api_base
    if args.mode == 'sft':
        url = api + "/api/inference/sft"
        payload = {
            'tts': args.tts_text,
            'role': args.spk_id
        }
        response = requests.post(url, data=payload)
        saveResponse(args.tts_wav, response)
    elif args.mode == 'zero_shot':
        url = api + "/api/inference/zero-shot"
        payload = {
            'tts': args.tts_text,
            'prompt': args.prompt_text
        }
        with open(args.prompt_wav, 'rb') as wav:
            files = [('audio', ('prompt_audio.wav', wav, 'application/octet-stream'))]
            response = requests.post(url, data=payload, files=files)
        saveResponse(args.tts_wav, response)
    elif args.mode == 'cross_lingual':
        url = api + "/api/inference/cross-lingual"
        payload = {
            'tts': args.tts_text,
        }
        with open(args.prompt_wav, 'rb') as wav:
            files = [('audio', ('prompt_audio.wav', wav, 'application/octet-stream'))]
            response = requests.post(url, data=payload, files=files)
        saveResponse(args.tts_wav, response)
    else:
        url = api + "/api/inference/instruct"
        payload = {
            'tts': args.tts_text,
            'role': args.spk_id,
            'instruct': args.instruct_text
        }
        response = requests.post(url, data=payload)
        saveResponse(args.tts_wav, response)
    logging.info("Response saved to %s", args.tts_wav)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument('--api_base',
                        type=str,
                        default='http://127.0.0.1:6006')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')
    parser.add_argument('--tts_text',
                        type=str,
                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
    parser.add_argument('--spk_id',
                        type=str,
                        default='中文女')
    parser.add_argument('--prompt_text',
                        type=str,
                        default='希望你以后能够做的比我还好呦。')
    parser.add_argument('--prompt_wav',
                        type=str,
                        default='../../zero_shot_prompt.wav')
    parser.add_argument('--instruct_text',
                        type=str,
                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
    parser.add_argument('--tts_wav',
                        type=str,
                        default='demo.wav')
    args = parser.parse_args()
    prompt_sr, target_sr = 16000, 22050
    main()
```
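
For the zero-shot and cross-lingual modes, the prompt recording travels as a multipart file part named `audio` alongside the form fields, as the client above shows. A minimal sketch of calling the zero-shot endpoint directly, assuming a 16 kHz prompt recording at the hypothetical path `prompt.wav`:

```python
import requests

url = "http://127.0.0.1:6006/api/inference/zero-shot"
data = {"tts": "Text to synthesize.", "prompt": "Transcript of the prompt audio."}

# The file part must be named 'audio' to match the server's UploadFile parameter
with open("prompt.wav", "rb") as f:
    files = {"audio": ("prompt.wav", f, "application/octet-stream")}
    resp = requests.post(url, data=data, files=files)

with open("zero_shot_demo.wav", "wb") as out:
    out.write(resp.content)
```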
runtime/python/fastapi_server.py
ADDED
```python
# Set inference model:
#   export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
# For development:
#   fastapi dev --port 6006 fastapi_server.py
# For production deployment:
#   fastapi run --port 6006 fastapi_server.py

import io
import logging
import os
import sys
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI, Response, File, UploadFile, Form
from fastapi.responses import HTMLResponse

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav
import numpy as np
import torch
import torchaudio

logging.getLogger('matplotlib').setLevel(logging.WARNING)


class LaunchFailed(Exception):
    pass


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the model once at startup and share it through the app object
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if model_dir:
        logging.info("MODEL_DIR is %s", model_dir)
        app.cosyvoice = CosyVoice('../../' + model_dir)
        # sft usage
        logging.info("Available speakers %s", app.cosyvoice.list_avaliable_spks())
    else:
        raise LaunchFailed("The MODEL_DIR environment variable must be set")
    yield

app = FastAPI(lifespan=lifespan)


def buildResponse(output):
    # Encode the synthesized speech tensor as an in-memory WAV file
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, 22050, format="wav")
    buffer.seek(0)
    return Response(content=buffer.read(), media_type="audio/wav")


@app.post("/api/inference/sft")
@app.get("/api/inference/sft")
async def sft(tts: str = Form(), role: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_sft(tts, role)
    end = time.process_time()
    logging.info("infer time is %f seconds", end - start)
    return buildResponse(output['tts_speech'])


@app.post("/api/inference/zero-shot")
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Quantize the prompt to 16-bit PCM, then convert back to a float tensor in [-1, 1)
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)

    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %f seconds", end - start)
    return buildResponse(output['tts_speech'])


@app.post("/api/inference/cross-lingual")
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Same 16-bit PCM round-trip as the zero-shot endpoint
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)

    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    end = time.process_time()
    logging.info("infer time is %f seconds", end - start)
    return buildResponse(output['tts_speech'])


@app.post("/api/inference/instruct")
@app.get("/api/inference/instruct")
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    start = time.process_time()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    end = time.process_time()
    logging.info("infer time is %f seconds", end - start)
    return buildResponse(output['tts_speech'])


@app.get("/api/roles")
async def roles():
    return {"roles": app.cosyvoice.list_avaliable_spks()}


@app.get("/", response_class=HTMLResponse)
async def root():
    return """
<!DOCTYPE html>
<html lang="zh-cn">
<head>
  <meta charset="utf-8">
  <title>API information</title>
</head>
<body>
  Get the supported speaker roles from the roles API first, then pass a role and the text content to the TTS API for synthesis. <a href='./docs'>API documentation</a>
</body>
</html>
"""
```
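
Each endpoint returns the raw WAV bytes produced by `buildResponse`, so a caller can decode the result straight back into a tensor. A small sketch, assuming `torchaudio` is installed on the client side:

```python
import io

import requests
import torchaudio

resp = requests.post("http://127.0.0.1:6006/api/inference/sft",
                     data={"tts": "你好", "role": "中文女"})

# The server encodes speech at 22050 Hz; decode the WAV bytes from memory
waveform, sample_rate = torchaudio.load(io.BytesIO(resp.content))
print(waveform.shape, sample_rate)
```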