CosyVoice committed on
Commit
c7d9754
2 Parent(s): 6cebcb3 6faabaa

Merge pull request #56 from iflamed/fastapi

Browse files

Add FastAPI server to serve TTS and a download script

README.md CHANGED
@@ -152,4 +152,4 @@ You can also scan the QR code to join our official Dingding chat group.
152
  5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
153
 
154
  ## Disclaimer
155
- The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
 
152
  5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
153
 
154
  ## Disclaimer
155
+ The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
requirements.txt CHANGED
@@ -26,4 +26,6 @@ tensorboard==2.14.0
26
  torch==2.0.1
27
  torchaudio==2.0.2
28
  wget==3.2
 
 
29
  WeTextProcessing==1.0.3
 
26
  torch==2.0.1
27
  torchaudio==2.0.2
28
  wget==3.2
29
+ fastapi==0.111.0
30
+ fastapi-cli==0.0.4
31
  WeTextProcessing==1.0.3
runtime/python/fastapi_client.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import requests
4
+
5
def saveResponse(path, response):
    """Write the raw bytes of an HTTP response body to a local file.

    Args:
        path: destination file path; overwritten if it already exists.
        response: a ``requests.Response`` whose ``content`` holds the
            synthesized audio bytes returned by the server.
    """
    payload = response.content
    with open(path, 'wb') as fh:
        fh.write(payload)
10
+
11
def main():
    """Send one synthesis request to the FastAPI server and save the WAV reply.

    Reads the module-level ``args`` namespace populated by argparse; the
    endpoint and payload are selected by ``args.mode``. The server's binary
    response is written to ``args.tts_wav``.
    """
    api = args.api_base
    if args.mode == 'sft':
        url = api + "/api/inference/sft"
        payload = {
            'tts': args.tts_text,
            'role': args.spk_id
        }
        response = requests.post(url, data=payload)
    elif args.mode == 'zero_shot':
        url = api + "/api/inference/zero-shot"
        payload = {
            'tts': args.tts_text,
            'prompt': args.prompt_text
        }
        # Use a context manager so the prompt-audio handle is closed
        # deterministically (the original left it open until exit).
        with open(args.prompt_wav, 'rb') as wav:
            files = [('audio', ('prompt_audio.wav', wav, 'application/octet-stream'))]
            response = requests.post(url, data=payload, files=files)
    elif args.mode == 'cross_lingual':
        url = api + "/api/inference/cross-lingual"
        payload = {
            'tts': args.tts_text,
        }
        with open(args.prompt_wav, 'rb') as wav:
            files = [('audio', ('prompt_audio.wav', wav, 'application/octet-stream'))]
            response = requests.post(url, data=payload, files=files)
    else:
        url = api + "/api/inference/instruct"
        payload = {
            'tts': args.tts_text,
            'role': args.spk_id,
            'instruct': args.instruct_text
        }
        response = requests.post(url, data=payload)
    saveResponse(args.tts_wav, response)
    # logging uses %-style lazy formatting; the original "{}" placeholder was
    # never substituted and made logging raise a formatting error internally.
    logging.info("Response saved to %s", args.tts_wav)
48
+
49
if __name__ == "__main__":
    # CLI for exercising the CosyVoice FastAPI server; defaults mirror the
    # demo assets shipped with the repository.
    parser = argparse.ArgumentParser(description='Client for the CosyVoice FastAPI TTS server')
    parser.add_argument('--api_base',
                        type=str,
                        default='http://127.0.0.1:6006',
                        help='base URL of the running fastapi_server instance')
    parser.add_argument('--mode',
                        default='sft',
                        choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
                        help='request mode')
    parser.add_argument('--tts_text',
                        type=str,
                        default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?',
                        help='text to synthesize')
    parser.add_argument('--spk_id',
                        type=str,
                        default='中文女',
                        help='speaker role for sft/instruct modes')
    parser.add_argument('--prompt_text',
                        type=str,
                        default='希望你以后能够做的比我还好呦。',
                        help='transcript of the prompt audio (zero_shot mode)')
    parser.add_argument('--prompt_wav',
                        type=str,
                        default='../../zero_shot_prompt.wav',
                        help='prompt audio file (zero_shot/cross_lingual modes)')
    parser.add_argument('--instruct_text',
                        type=str,
                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.',
                        help='style instruction (instruct mode)')
    parser.add_argument('--tts_wav',
                        type=str,
                        default='demo.wav',
                        help='where to save the synthesized WAV response')
    args = parser.parse_args()
    # NOTE(review): the original defined prompt_sr/target_sr = 16000/22050
    # here but never used them; removed as dead locals.
    main()
runtime/python/fastapi_server.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Set inference model
2
+ # export MODEL_DIR=pretrained_models/CosyVoice-300M-Instruct
3
+ # For development
4
+ # fastapi dev --port 6006 fastapi_server.py
5
+ # For production deployment
6
+ # fastapi run --port 6006 fastapi_server.py
7
+
8
+ import os
9
+ import sys
10
+ import io,time
11
+ from fastapi import FastAPI, Response, File, UploadFile, Form
12
+ from fastapi.responses import HTMLResponse
13
+ from contextlib import asynccontextmanager
14
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
15
+ sys.path.append('{}/../..'.format(ROOT_DIR))
16
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
17
+ from cosyvoice.cli.cosyvoice import CosyVoice
18
+ from cosyvoice.utils.file_utils import load_wav
19
+ import numpy as np
20
+ import torch
21
+ import torchaudio
22
+ import logging
23
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
24
+
25
class LaunchFailed(Exception):
    """Raised when the server cannot start (e.g. the model dir is missing)."""
    pass
27
+
28
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the CosyVoice model once at startup.

    The model directory comes from the MODEL_DIR environment variable,
    defaulting to the 300M SFT checkpoint. The loaded model is attached to
    the app instance so the route handlers can reach it.
    """
    model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
    if model_dir:
        # logging uses %-style lazy formatting; the original "{}" placeholders
        # were never substituted into the message.
        logging.info("MODEL_DIR is %s", model_dir)
        app.cosyvoice = CosyVoice('../../' + model_dir)
        # sft usage
        logging.info("Available speakers %s", app.cosyvoice.list_avaliable_spks())
    else:
        # NOTE(review): unreachable in practice — os.getenv above supplies a
        # non-empty default, so model_dir is always truthy. Kept for safety.
        raise LaunchFailed("MODEL_DIR environment must set")
    yield
39
+
40
# Application instance; the lifespan handler loads the model before serving.
app = FastAPI(lifespan=lifespan)
41
+
42
def buildResponse(output, sample_rate=22050):
    """Encode a speech tensor as a WAV payload wrapped in an HTTP response.

    Args:
        output: waveform tensor accepted by ``torchaudio.save``
            (channels x samples) — assumed float PCM; TODO confirm.
        sample_rate: WAV sample rate in Hz. Defaults to 22050, the rate the
            CosyVoice pipeline produces (generalized from a hard-coded value).

    Returns:
        ``fastapi.Response`` with media type ``audio/wav``.
    """
    buffer = io.BytesIO()
    torchaudio.save(buffer, output, sample_rate, format="wav")
    # getvalue() yields the entire encoded buffer; no seek/read(-1) dance needed.
    return Response(content=buffer.getvalue(), media_type="audio/wav")
47
+
48
@app.post("/api/inference/sft")
@app.get("/api/inference/sft")
async def sft(tts: str = Form(), role: str = Form()):
    """Synthesize ``tts`` text with the pretrained speaker ``role`` (SFT mode)."""
    start = time.process_time()
    output = app.cosyvoice.inference_sft(tts, role)
    end = time.process_time()
    # Fixed: logging needs %-style lazy args, not "{}" placeholders.
    # NOTE(review): process_time() measures CPU time, not wall-clock — confirm intent.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
56
+
57
@app.post("/api/inference/zero-shot")
async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
    """Zero-shot voice cloning: synthesize ``tts`` in the uploaded prompt voice.

    Args:
        tts: text to synthesize.
        prompt: transcript of the uploaded prompt audio.
        audio: prompt audio upload, resampled to 16 kHz by ``load_wav``.
    """
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Round-trip the float waveform through int16 PCM bytes — presumably to
    # match the quantization of on-disk WAV prompts; TODO confirm intent.
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    # np.array(...) copies the read-only frombuffer view so torch can own it.
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)

    output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
    end = time.process_time()
    # Fixed: logging needs %-style lazy args, not "{}" placeholders.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
69
+
70
@app.post("/api/inference/cross-lingual")
async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
    """Cross-lingual synthesis: render ``tts`` using the uploaded prompt voice.

    Args:
        tts: text to synthesize (may be in a different language than the prompt).
        audio: prompt audio upload, resampled to 16 kHz by ``load_wav``.
    """
    start = time.process_time()
    prompt_speech = load_wav(audio.file, 16000)
    # Round-trip the float waveform through int16 PCM bytes — presumably to
    # match the quantization of on-disk WAV prompts; TODO confirm intent.
    prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
    # np.array(...) copies the read-only frombuffer view so torch can own it.
    prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
    prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)

    output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
    end = time.process_time()
    # Fixed: logging needs %-style lazy args, not "{}" placeholders.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
82
+
83
@app.post("/api/inference/instruct")
@app.get("/api/inference/instruct")
async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
    """Instruct-mode synthesis: render ``tts`` as ``role`` following ``instruct``."""
    start = time.process_time()
    output = app.cosyvoice.inference_instruct(tts, role, instruct)
    end = time.process_time()
    # Fixed: logging needs %-style lazy args, not "{}" placeholders.
    logging.info("infer time is %s seconds", end - start)
    return buildResponse(output['tts_speech'])
91
+
92
@app.get("/api/roles")
async def roles():
    """Return the speaker roles offered by the loaded model."""
    spk_list = app.cosyvoice.list_avaliable_spks()
    return {"roles": spk_list}
95
+
96
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve a minimal landing page that points users at the API docs."""
    page = """
    <!DOCTYPE html>
    <html lang=zh-cn>
    <head>
    <meta charset=utf-8>
    <title>Api information</title>
    </head>
    <body>
    Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
    </body>
    </html>
    """
    return page