File size: 3,969 Bytes
01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from fastapi import HTTPException
from pydantic import BaseModel
import torch
from modules.speaker import speaker_mgr
from modules.api import utils as api_utils
from modules.api.Api import APIManager
class CreateSpeaker(BaseModel):
    """Request body for ``POST /v1/speaker/create``.

    The handler builds the speaker from ``tensor`` when it is a non-empty
    list, and falls back to ``seed`` otherwise. Both fields are still
    required by this model.
    """

    seed: int  # only used when ``tensor`` is empty
    name: str
    gender: str
    describe: str
    tensor: list  # number array passed to ``torch.tensor``; send [] to use ``seed``
class UpdateSpeaker(BaseModel):
    """Request body for ``POST /v1/speaker/update``.

    ``name``, ``gender`` and ``describe`` always overwrite the speaker's
    current values. ``tensor`` replaces the speaker's embedding only when it
    is a non-empty list.
    """

    id: str  # id of an existing speaker; unknown ids yield HTTP 404
    name: str
    gender: str
    describe: str
    tensor: list  # number array for the new embedding; [] keeps the current one
class SpeakerDetail(BaseModel):
    """Request body for ``POST /v1/speaker/detail``."""

    id: str  # id of the speaker to return; unknown ids yield HTTP 404
    with_emb: bool = False  # forwarded as ``speaker.to_json(with_emb=...)``
class SpeakersUpdate(BaseModel):
    """Request body for ``POST /v1/speakers/update`` (bulk update)."""

    # Each item is expected to be a dict with a required "id" key and optional
    # "name", "gender", "describe" and "tensor" (number array) keys. Element
    # shape is not validated by this model; the handler checks it.
    speakers: list
def _tensor_from_list(values, field: str = "tensor") -> torch.Tensor:
    """Convert a client-supplied number array into a ``torch.Tensor``.

    Args:
        values: The raw (possibly nested) list taken from the request body.
        field: Name of the request field, used in the error message.

    Raises:
        HTTPException: 400 when ``torch.tensor`` cannot build a tensor from
            ``values`` (e.g. ragged or non-numeric entries). Without this the
            error would surface to the client as a 500.
    """
    try:
        return torch.tensor(values)
    except (TypeError, ValueError, RuntimeError) as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid {field}: {err}"
        ) from err


def _get_speaker_or_404(speaker_id):
    """Return the speaker with ``speaker_id``, or raise HTTP 404 if absent."""
    speaker = speaker_mgr.get_speaker_by_id(speaker_id)
    if speaker is None:
        raise HTTPException(
            status_code=404, detail=f"Speaker not found: {speaker_id}"
        )
    return speaker


def setup(app: APIManager):
    """Register the speaker-management routes on ``app``."""

    @app.get("/v1/speakers/list", response_model=api_utils.BaseResponse)
    async def list_speakers():
        """Return every speaker known to ``speaker_mgr`` as JSON."""
        return {
            "message": "ok",
            "data": [spk.to_json() for spk in speaker_mgr.list_speakers()],
        }

    @app.post("/v1/speakers/update", response_model=api_utils.BaseResponse)
    async def update_speakers(request: SpeakersUpdate):
        """Bulk-update several speakers, then call ``speaker_mgr.save_all()``.

        Each item must be a dict with an ``"id"`` key. ``name``, ``gender``
        and ``describe`` keep their current values when absent. A non-empty
        ``tensor`` list replaces the speaker's embedding.

        All items are validated (id present and known, tensor convertible)
        before any speaker is modified. A bad item therefore rejects the whole
        request instead of leaving earlier speakers half-updated in memory.
        """
        pending = []
        for spk in request.speakers:
            if not isinstance(spk, dict) or "id" not in spk:
                raise HTTPException(
                    status_code=400,
                    detail="Each speaker entry must be an object with an 'id'",
                )
            speaker = _get_speaker_or_404(spk["id"])
            raw_tensor = spk.get("tensor")
            # Only a non-empty list replaces the embedding; anything else is ignored.
            emb = (
                _tensor_from_list(raw_tensor)
                if isinstance(raw_tensor, list) and raw_tensor
                else None
            )
            pending.append((speaker, spk, emb))

        # Second pass: every item is valid, so apply all changes.
        for speaker, spk, emb in pending:
            speaker.name = spk.get("name", speaker.name)
            speaker.gender = spk.get("gender", speaker.gender)
            speaker.describe = spk.get("describe", speaker.describe)
            if emb is not None:
                speaker.emb = emb
        speaker_mgr.save_all()
        return {"message": "ok", "data": None}

    @app.post("/v1/speaker/create", response_model=api_utils.BaseResponse)
    async def create_speaker(request: CreateSpeaker):
        """Create a speaker from ``request.tensor`` if non-empty, else from ``request.seed``."""
        # ``tensor`` is validated as a list by pydantic, so truthiness means non-empty.
        if request.tensor:
            speaker = speaker_mgr.create_speaker_from_tensor(
                tensor=_tensor_from_list(request.tensor),
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        else:
            speaker = speaker_mgr.create_speaker_from_seed(
                seed=request.seed,
                name=request.name,
                gender=request.gender,
                describe=request.describe,
            )
        return {"message": "ok", "data": speaker.to_json()}

    @app.post("/v1/speaker/refresh", response_model=api_utils.BaseResponse)
    async def refresh_speakers():
        """Ask ``speaker_mgr`` to refresh its speaker list."""
        speaker_mgr.refresh_speakers()
        return {"message": "ok"}

    @app.post("/v1/speaker/update", response_model=api_utils.BaseResponse)
    async def update_speaker(request: UpdateSpeaker):
        """Overwrite one speaker's fields and pass it to ``speaker_mgr.update_speaker``."""
        speaker = _get_speaker_or_404(request.id)
        # Convert the tensor before touching the speaker so an invalid tensor
        # (HTTP 400) leaves it unmodified.
        emb = _tensor_from_list(request.tensor) if request.tensor else None
        speaker.name = request.name
        speaker.gender = request.gender
        speaker.describe = request.describe
        if emb is not None:
            speaker.emb = emb
        speaker_mgr.update_speaker(speaker)
        return {"message": "ok"}

    @app.post("/v1/speaker/detail", response_model=api_utils.BaseResponse)
    async def speaker_detail(request: SpeakerDetail):
        """Return one speaker's JSON, including the embedding if ``with_emb``."""
        speaker = _get_speaker_or_404(request.id)
        return {"message": "ok", "data": speaker.to_json(with_emb=request.with_emb)}
|