# Spaces:
# Sleeping
# Sleeping
import argparse
import asyncio
import functools
import json
import os
from io import BytesIO
from typing import BinaryIO

import uvicorn
from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request
from fastapi.responses import StreamingResponse
from faster_whisper import WhisperModel
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from zhconv import convert

from utils.data_utils import remove_punctuation
from utils.utils import add_arguments, print_arguments
# Allow more than one OpenMP runtime to be loaded in this process (Intel
# OpenMP's KMP_DUPLICATE_LIB_OK switch); set before the model is loaded.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Command-line configuration. ``add_arguments`` (from utils.utils) registers
# each option on ``parser``.
# NOTE(review): plain argparse treats any non-empty string as True for
# ``type=bool``; presumably ``add_arguments`` maps bool to a string-to-bool
# parser -- confirm in utils.utils.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("host", type=str, default="0.0.0.0",
        help="Host address the HTTP server binds to.")
add_arg("port", type=int, default=5000,
        help="Port the HTTP server listens on.")
add_arg("model_path", type=str, default="models/sam2ai/whisper-odia-small-finetune-int8-ct2",
        help="Path to the faster-whisper (CTranslate2) model directory.")
add_arg("use_gpu", type=bool, default=False,
        help="Run inference on CUDA instead of the CPU.")
add_arg("use_int8", type=bool, default=True,
        help="On GPU, use int8_float16 instead of float16 compute (the CPU always uses int8).")
add_arg("beam_size", type=int, default=10,
        help="Beam size for decoding.")
add_arg("num_workers", type=int, default=2,
        help="num_workers passed to faster-whisper's WhisperModel.")
add_arg("vad_filter", type=bool, default=True,
        help="Enable faster-whisper's voice activity detection filter.")
add_arg("local_files_only", type=bool, default=True,
        help="Load the model from local files only.")
args = parser.parse_args()
print_arguments(args)
# Fail fast with a clear message if the model directory is missing. An
# ``assert`` would be stripped under ``python -O`` and only echoed the path.
if not os.path.exists(args.model_path):
    raise FileNotFoundError(f"Model path does not exist: {args.model_path}")

# Pick the device and CTranslate2 compute type. The CPU always runs "int8".
# On CUDA, ``use_int8`` selects "int8_float16" (int8 weights, float16
# elsewhere); otherwise the model runs in plain "float16".
if args.use_gpu:
    _device = "cuda"
    _compute_type = "int8_float16" if args.use_int8 else "float16"
else:
    _device = "cpu"
    _compute_type = "int8"
model = WhisperModel(args.model_path, device=_device, compute_type=_compute_type,
                     num_workers=args.num_workers,
                     local_files_only=args.local_files_only)
# ---------------------------------------------------------------------------
# Web application
# ---------------------------------------------------------------------------
# A start-up warm-up transcription (running the model once on a sample file)
# is intentionally left disabled.
app = FastAPI(title="OdiaGenAI Whisper ASR")
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Concurrency limiter for inference requests. It stays None here and is
# created lazily inside a request handler, i.e. on the running event loop.
model_semaphore = None
def release_model_semaphore():
    """Give one permit back to the module-level ``model_semaphore``.

    ``model_semaphore`` starts out as ``None`` and is only created by a request
    handler, so calling this before then raises ``AttributeError``.
    """
    semaphore = model_semaphore
    semaphore.release()
def recognition(file: BinaryIO, to_simple: int,
                remove_pun: int, language: str = "bn",
                task: str = "transcribe"
                ):
    """Transcribe ``file`` and yield one encoded JSON record per segment.

    Args:
        file: Audio data to decode (both callers in this module pass an
            in-memory ``BytesIO``).
        to_simple: Accepted for API compatibility; currently ignored (the
            zhconv Simplified-Chinese conversion is disabled).
        remove_pun: Accepted for API compatibility; currently ignored (the
            ``remove_punctuation`` step is disabled).
        language: Language code passed to Whisper; ``None`` lets the model
            detect it. NOTE(review): the "bn" default on an Odia fine-tune
            looks deliberate -- confirm.
        task: Whisper task, e.g. "transcribe" or "translate".

    Yields:
        bytes: ``{"result": text, "start": s, "end": e}`` as JSON (times
        rounded to 2 decimals), encoded and terminated by a NUL byte.
    """
    # faster-whisper returns segments lazily: decoding happens while the loop
    # below iterates them. Beam size now honors the --beam_size option
    # instead of a hard-coded 10.
    segments, _info = model.transcribe(file, beam_size=args.beam_size, task=task,
                                       language=language, vad_filter=args.vad_filter)
    for segment in segments:
        record = {"result": segment.text,
                  "start": round(segment.start, 2),
                  "end": round(segment.end, 2)}
        # The NUL byte separates records so streaming clients can split them.
        yield json.dumps(record).encode() + b"\0"
@app.post("/recognition_stream")
async def api_recognition_stream(
        to_simple: int = Body(1, description="1 = convert output to Simplified Chinese (currently ignored).",
                              embed=True),
        remove_pun: int = Body(0, description="1 = strip punctuation from the output (currently ignored).",
                               embed=True),
        language: str = Body("bn", description='Language code passed to Whisper; send "None" to auto-detect.',
                             embed=True),
        task: str = Body("transcribe", description='Whisper task, e.g. "transcribe" or "translate".',
                         embed=True),
        audio: UploadFile = File(..., description="Audio file to transcribe.")
):
    """Transcribe an uploaded audio file, streaming one record per segment.

    The response body is the byte stream produced by :func:`recognition`:
    JSON objects ``{"result", "start", "end"}``, each terminated by a NUL byte.
    """
    global model_semaphore
    if language == "None":
        language = None
    # Created lazily so the semaphore belongs to the running event loop.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(5)
    contents = await audio.read()
    generator = recognition(file=BytesIO(contents), to_simple=to_simple,
                            remove_pun=remove_pun, language=language, task=task)

    async def stream_segments():
        # The permit is acquired and released inside the body iterator itself.
        # It is therefore returned on normal completion, on an error raised
        # mid-stream (e.g. undecodable audio) and on cancellation alike.
        # Releasing via a response background task leaked the permit whenever
        # the stream raised, because Starlette does not run the background task
        # in that case.
        async with model_semaphore:
            # Transcription is blocking; pull each segment in the threadpool so
            # the event loop stays responsive.
            async for chunk in iterate_in_threadpool(generator):
                yield chunk

    return StreamingResponse(stream_segments())
@app.post("/recognition")
async def api_recognition(
        to_simple: int = Body(1, description="1 = convert output to Simplified Chinese (currently ignored).",
                              embed=True),
        remove_pun: int = Body(0, description="1 = strip punctuation from the output (currently ignored).",
                               embed=True),
        language: str = Body("bn", description='Language code passed to Whisper; send "None" to auto-detect.',
                             embed=True),
        task: str = Body("transcribe", description='Whisper task, e.g. "transcribe" or "translate".',
                         embed=True),
        audio: UploadFile = File(..., description="Audio file to transcribe.")
):
    """Transcribe an uploaded audio file and return all segments at once.

    Returns:
        dict: ``{"results": [{"result", "start", "end"}, ...], "code": 0}``.
    """
    global model_semaphore
    if language == "None":
        language = None
    # Count against the same concurrency limit as the streaming endpoint;
    # created lazily so it belongs to the running event loop.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(5)
    contents = await audio.read()
    generator = recognition(file=BytesIO(contents), to_simple=to_simple,
                            remove_pun=remove_pun, language=language, task=task)

    def decode_all():
        # Each record is JSON followed by one NUL terminator; drop it and parse.
        return [json.loads(chunk[:-1].decode("utf-8")) for chunk in generator]

    # Iterating ``generator`` runs the (blocking) model. Doing it in the
    # threadpool keeps this async handler from stalling the event loop.
    async with model_semaphore:
        results = await run_in_threadpool(decode_all)
    return {"results": results, "code": 0}
@app.get("/")
async def index(request: Request):
    """Render ``index.html`` from the templates directory."""
    # The Jinja2 TemplateResponse context must include ``request``. The old
    # context also passed the builtin ``id`` function, which is not template
    # data, so it is no longer included.
    return templates.TemplateResponse("index.html", {"request": request})
if __name__ == "__main__":
    # Serve the ASGI app with uvicorn on the address given on the command line.
    uvicorn.run(app, port=args.port, host=args.host)