iflamed committed
Commit a4ab4ea
Parent: 3513376

Support audio upload

Files changed (2)
  1. main.py +0 -40
  2. runtime/python/fastapi_server.py +102 -0
main.py DELETED
@@ -1,40 +0,0 @@
- import io,time
- from fastapi import FastAPI, Response
- from fastapi.responses import HTMLResponse
- from cosyvoice.cli.cosyvoice import CosyVoice
- import torchaudio
-
- cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
- # sft usage
- print(cosyvoice.list_avaliable_spks())
- app = FastAPI()
-
- @app.get("/api/voice/tts")
- async def tts(query: str, role: str):
-     start = time.process_time()
-     output = cosyvoice.inference_sft(query, role)
-     end = time.process_time()
-     print("infer time:", end-start, "seconds")
-     buffer = io.BytesIO()
-     torchaudio.save(buffer, output['tts_speech'], 22050, format="wav")
-     buffer.seek(0)
-     return Response(content=buffer.read(-1), media_type="audio/wav")
-
- @app.get("/api/voice/roles")
- async def roles():
-     return {"roles": cosyvoice.list_avaliable_spks()}
-
- @app.get("/", response_class=HTMLResponse)
- async def root():
-     return """
-     <!DOCTYPE html>
-     <html lang=zh-cn>
-     <head>
-         <meta charset=utf-8>
-         <title>Api information</title>
-     </head>
-     <body>
-         Get the supported tones from the Roles API first, then enter the tones and textual content in the TTS API for synthesis. <a href='./docs'>Documents of API</a>
-     </body>
-     </html>
-     """
runtime/python/fastapi_server.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import sys
+ import io, time
+ from fastapi import FastAPI, Response, File, UploadFile, Form
+ from fastapi.responses import HTMLResponse
+ from contextlib import asynccontextmanager
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ sys.path.append('{}/../..'.format(ROOT_DIR))
+ sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
+ from cosyvoice.cli.cosyvoice import CosyVoice
+ from cosyvoice.utils.file_utils import load_wav
+ import numpy as np
+ import torch
+ import torchaudio
+ import logging
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
+
+ class LaunchFailed(Exception):
+     pass
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     model_dir = os.getenv("MODEL_DIR", "pretrained_models/CosyVoice-300M-SFT")
+     if model_dir:
+         logging.info("MODEL_DIR is %s", model_dir)
+         app.cosyvoice = CosyVoice('../../' + model_dir)
+         # sft usage
+         logging.info("Available speakers: %s", app.cosyvoice.list_avaliable_spks())
+     else:
+         raise LaunchFailed("The MODEL_DIR environment variable must be set")
+     yield
+
+ app = FastAPI(lifespan=lifespan)
+
+ def buildResponse(output):
+     # Serialize the speech tensor to an in-memory WAV file and return it as audio.
+     buffer = io.BytesIO()
+     torchaudio.save(buffer, output, 22050, format="wav")
+     buffer.seek(0)
+     return Response(content=buffer.read(), media_type="audio/wav")
+
+ @app.post("/api/inference/sft")
+ @app.get("/api/inference/sft")
+ async def sft(tts: str = Form(), role: str = Form()):
+     start = time.process_time()
+     output = app.cosyvoice.inference_sft(tts, role)
+     end = time.process_time()
+     logging.info("infer time is %s seconds", end - start)
+     return buildResponse(output['tts_speech'])
+
+ @app.post("/api/inference/zero-shot")
+ async def zeroShot(tts: str = Form(), prompt: str = Form(), audio: UploadFile = File()):
+     start = time.process_time()
+     prompt_speech = load_wav(audio.file, 16000)
+     # Round-trip through 16-bit PCM so the uploaded prompt matches the
+     # quantization of audio loaded from disk, then rescale to float in [-1, 1).
+     prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
+     prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
+     prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)
+
+     output = app.cosyvoice.inference_zero_shot(tts, prompt, prompt_speech_16k)
+     end = time.process_time()
+     logging.info("infer time is %s seconds", end - start)
+     return buildResponse(output['tts_speech'])
+
+ @app.post("/api/inference/cross-lingual")
+ async def crossLingual(tts: str = Form(), audio: UploadFile = File()):
+     start = time.process_time()
+     prompt_speech = load_wav(audio.file, 16000)
+     prompt_audio = (prompt_speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()
+     prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
+     prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)
+
+     output = app.cosyvoice.inference_cross_lingual(tts, prompt_speech_16k)
+     end = time.process_time()
+     logging.info("infer time is %s seconds", end - start)
+     return buildResponse(output['tts_speech'])
+
+ @app.post("/api/inference/instruct")
+ @app.get("/api/inference/instruct")
+ async def instruct(tts: str = Form(), role: str = Form(), instruct: str = Form()):
+     start = time.process_time()
+     output = app.cosyvoice.inference_instruct(tts, role, instruct)
+     end = time.process_time()
+     logging.info("infer time is %s seconds", end - start)
+     return buildResponse(output['tts_speech'])
+
+ @app.get("/api/roles")
+ async def roles():
+     return {"roles": app.cosyvoice.list_avaliable_spks()}
+
+ @app.get("/", response_class=HTMLResponse)
+ async def root():
+     return """
+     <!DOCTYPE html>
+     <html lang=zh-cn>
+     <head>
+         <meta charset=utf-8>
+         <title>API information</title>
+     </head>
+     <body>
+         Get the supported roles from the Roles API first, then pass a role and the text to synthesize to the TTS APIs. <a href='./docs'>API documentation</a>
+     </body>
+     </html>
+     """
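
For reference, a minimal Python client sketch against the new endpoints. The host and port (http://localhost:8000, e.g. after launching with `uvicorn fastapi_server:app`), the prompt file name `prompt.wav`, and the example texts are placeholder assumptions, not part of this commit. Note that zero-shot and cross-lingual are POST-only because they carry a file upload, while sft and instruct also accept GET.

import requests

BASE = "http://localhost:8000"  # assumed host/port; adjust to your deployment

# List the available speaker roles.
roles = requests.get(BASE + "/api/roles").json()["roles"]

# SFT synthesis: plain form fields; the response body is a WAV file.
resp = requests.post(BASE + "/api/inference/sft",
                     data={"tts": "Text to synthesize.", "role": roles[0]})
with open("sft.wav", "wb") as f:
    f.write(resp.content)

# Zero-shot cloning: upload a prompt recording alongside the form fields
# (the server resamples the upload to 16 kHz via load_wav).
with open("prompt.wav", "rb") as audio:  # placeholder prompt recording
    resp = requests.post(BASE + "/api/inference/zero-shot",
                         data={"tts": "Text to synthesize.",
                               "prompt": "Transcript of prompt.wav."},
                         files={"audio": audio})
with open("zero_shot.wav", "wb") as f:
    f.write(resp.content)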