File size: 3,265 Bytes
01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from fastapi import Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
import io
from pydantic import BaseModel
import soundfile as sf
from fastapi.responses import FileResponse
from modules.normalization import text_normalize
from modules import generate_audio as generate
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.synthesize_audio import synthesize_audio
class TTSParams(BaseModel):
    """Query parameters accepted by the TTS endpoint.

    Every field is declared with ``fastapi.Query`` so that, when the model is
    injected with ``Depends()``, FastAPI reads each value from the URL query
    string. The sampling-related fields are described as values that "may be
    overridden by style or spk".
    """

    # --- required input -------------------------------------------------
    text: str = Query(default=..., description="Text to synthesize")

    # --- voice selection ------------------------------------------------
    spk: str = Query(
        default="female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query(default="chat", description="Specific style by style name")

    # --- sampling controls ----------------------------------------------
    temperature: float = Query(
        default=0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_P: float = Query(
        default=0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_K: int = Query(
        default=20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        default=42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # --- output encoding ------------------------------------------------
    format: str = Query(
        default="mp3", description="Response audio format: [mp3,wav]"
    )

    # --- prompt shaping -------------------------------------------------
    prompt1: str = Query(default="", description="Text prompt for inference")
    prompt2: str = Query(default="", description="Text prompt for inference")
    prefix: str = Query(default="", description="Text prefix for inference")

    # --- inference tuning -----------------------------------------------
    # Both are declared as strings; the endpoint converts them with int().
    bs: str = Query(default="8", description="Batch size for inference")
    thr: str = Query(default="100", description="Threshold for sentence spliter")
async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize ``params.text`` to speech and stream the encoded audio.

    The text is normalized with ``text_normalize``. Speaker/style-derived
    settings are resolved with ``api_utils.calc_spk_style``, audio is produced
    by ``synthesize_audio`` and then encoded as WAV. When
    ``params.format == "mp3"``, the WAV buffer is converted with
    ``api_utils.wav_to_mp3``.

    Args:
        params: Query parameters collected into a ``TTSParams`` instance.

    Returns:
        StreamingResponse: the encoded audio. The ``media_type`` is
        ``audio/mpeg`` for mp3 and ``audio/wav`` otherwise.

    Raises:
        HTTPException: status 500 wrapping any exception raised while
            normalizing, resolving parameters, synthesizing or encoding
            (including ``ValueError`` from a non-integer ``bs``/``thr``).
    """
    try:
        text = text_normalize(params.text, is_end=False)
        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        # A speaker resolved from spk/style always takes precedence.
        spk = calc_params.get("spk", params.spk)

        # For the settings below, the request value wins whenever it is truthy.
        # The spk/style-derived value is only used when the request value is
        # falsy (0, 0.0 or ""). Because seed (42) and temperature (0.3) have
        # truthy defaults, the derived values apply to them only when the
        # client explicitly sends 0. top_P and top_K are never overridden here.
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)

        # bs/thr arrive as strings (see TTSParams); int() raises ValueError on
        # bad input, which is reported as a 500 below.
        batch_size = int(params.bs)
        threshold = int(params.thr)

        # NOTE(review): synthesize_audio is called synchronously inside an
        # async handler; if it is CPU-bound it blocks the event loop for the
        # whole request. Consider offloading once its thread-safety is confirmed.
        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=params.top_P,
            top_K=params.top_K,
            spk=spk,
            infer_seed=seed,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
            batch_size=batch_size,
            spliter_threshold=threshold,
        )

        # Always encode to WAV first; MP3 is derived from the WAV buffer.
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        # Compare against the requested format field. A bare ``format`` here
        # would be Python's builtin function and never equal "mp3".
        if params.format == "mp3":
            # Presumably returns a readable file-like object; TODO confirm
            # against api_utils.wav_to_mp3.
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mpeg"
        else:
            media_type = "audio/wav"

        return StreamingResponse(buffer, media_type=media_type)
    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def setup(api_manager: APIManager):
    """Register ``synthesize_tts`` on *api_manager* as ``GET /v1/tts``.

    The route declares ``FileResponse`` as its response class. The handler
    itself returns a ``StreamingResponse`` instance.
    """
    # api_manager.get(...) returns a decorator; apply it to the handler.
    register_tts_route = api_manager.get("/v1/tts", response_class=FileResponse)
    register_tts_route(synthesize_tts)
|