# This module handles the task queue.
import multiprocessing
import os
from datetime import datetime
from typing import TypedDict

import librosa
import numpy as np
from datasets import load_from_disk

from codes import convert
from hf import get_repo_info, add_leaderboard_entry
from inference import clear_cache, load_model, transcribe
from metrics import per, fer
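
# Guards concurrent leaderboard writes across evaluation worker processes.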
leaderboard_lock = multiprocessing.Lock()


class Task(TypedDict):
status: str
display_name: str
repo_id: str
repo_hash: str
repo_last_modified: datetime
submission_timestamp: datetime
model_type: str
phone_code: str
model_bytes: int | None
url: str
    error: str | None


tasks: list[Task] = []


def get_status(query: str) -> dict:
    """Check the status of an evaluation task by its repo_id or repo_hash."""
query = query.strip().lower()
if not query:
return {"error": "Please enter a model id or task id"}
for task in reversed(tasks):
if task["repo_id"].lower() == query or task["repo_hash"].lower() == query:
return dict(task)
return {"error": f"No results found for '{query}'"}
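
# Example (hypothetical ids): poll by repo id or by the task hash returned
# from start_eval_task; the most recent matching submission is returned:
#   get_status("someuser/wav2vec2-ipa")  # -> {"status": "evaluating", ...}
#   get_status("a1b2c3d4")               # -> {"status": "completed", ...}
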
def start_eval_task(
display_name: str, repo_id: str, url: str, model_type: str, phone_code: str
) -> str:
"""Start evaluation task in background. Returns task ID that can be used to check status."""
repo_hash, last_modified, size_bytes = get_repo_info(repo_id)
    # TODO: skip re-evaluation when the hash matches the most recent submission for this repo_id, if any
task = Task(
status="submitted",
display_name=display_name,
repo_id=repo_id,
repo_hash=repo_hash,
repo_last_modified=last_modified,
submission_timestamp=datetime.now(),
model_type=model_type,
phone_code=phone_code,
model_bytes=size_bytes,
url=url,
error=None,
)
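    # Share mutable task state with the spawned worker via a managed dict proxy;
    # a plain dict would not reflect the worker's status updates in this process.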
manager = multiprocessing.Manager()
task_proxy = manager.dict(task)
tasks.append(task_proxy) # type: ignore
multiprocessing.Process(
target=_eval_task, args=[task_proxy, leaderboard_lock]
).start()
return repo_hash
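
# Example (hypothetical values): submitting returns the repo hash, which also
# serves as the task id for get_status:
#   task_id = start_eval_task(
#       display_name="My Model",
#       repo_id="someuser/wav2vec2-ipa",
#       url="https://example.com/my-model",
#       model_type="wav2vec2",
#       phone_code="ipa",
#   )
#   get_status(task_id)  # -> {"status": "submitted", ...}
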
test_ds = load_from_disk(os.path.join(os.path.dirname(__file__), "data", "test"))
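# Each row is expected to provide "audio" (a dict with an "array" field),
# "ipa" (the reference transcript), and "dataset" (the source corpus name),
# matching the fields _eval_task reads below.
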
def _eval_task(task: Task, leaderboard_lock):
"""Background task to evaluate model and save updated results"""
try:
# Indicate task is processing
task["status"] = "evaluating"
# Evaluate model
        average_per = 0.0
        average_fer = 0.0
        per_dataset_fers: dict[str, float] = {}
        per_dataset_counts: dict[str, int] = {}
        clear_cache()
        model = load_model(task["repo_id"], task["model_type"])
        for row in test_ds:
            transcript = transcribe(row["audio"]["array"], task["model_type"], model)  # type: ignore
            if task["phone_code"] != "ipa":
                transcript = convert(transcript, task["phone_code"], "ipa")
            row_per = per(transcript, row["ipa"])  # type: ignore
            row_fer = fer(transcript, row["ipa"])  # type: ignore
            average_per += row_per
            average_fer += row_fer
            dataset = row["dataset"]  # type: ignore
            per_dataset_fers[dataset] = per_dataset_fers.get(dataset, 0.0) + row_fer
            per_dataset_counts[dataset] = per_dataset_counts.get(dataset, 0) + 1
        # Average per-dataset FERs using counts gathered in the loop above,
        # rather than re-filtering the whole dataset once per corpus.
        for key in per_dataset_fers:
            per_dataset_fers[key] /= per_dataset_counts[key]
        average_per /= len(test_ds)
        average_fer /= len(test_ds)
# Save results
with leaderboard_lock:
add_leaderboard_entry(
display_name=task["display_name"],
repo_id=task["repo_id"],
repo_hash=task["repo_hash"],
repo_last_modified=task["repo_last_modified"],
submission_timestamp=task["submission_timestamp"],
average_per=average_per,
average_fer=average_fer,
url=task["url"],
model_bytes=task["model_bytes"],
per_dataset_fers=per_dataset_fers,
)
# Mark task as complete
task["status"] = "completed"
except Exception as e:
task["status"] = "failed"
task["error"] = str(e)


def run_sample_inference(audio, model_id: str, model_type: str, phone_code: str):
    """Transcribe one (sample_rate, waveform) audio tuple with the given model."""
    clear_cache()
    # Load model
    model = load_model(model_id, model_type)
    # Format audio as mono 16 kHz float32
    sample_rate, wav_array = audio
    if np.issubdtype(wav_array.dtype, np.integer):
        # Scale integer PCM (e.g. int16) into [-1.0, 1.0]; a bare cast would not.
        wav_array = wav_array.astype(np.float32) / np.iinfo(wav_array.dtype).max
    else:
        wav_array = wav_array.astype(np.float32)
    if wav_array.ndim == 2:
        # Downmix multichannel audio (shape: samples x channels) to mono.
        wav_array = wav_array.mean(axis=1)
    wav_array = librosa.resample(y=wav_array, orig_sr=sample_rate, target_sr=16_000)
# Transcribe
transcript = transcribe(wav_array, model_type, model)
if phone_code != "ipa":
transcript = convert(transcript, phone_code, "ipa")
clear_cache()
return transcript
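
# Example (hypothetical model id): a Gradio-style audio input is a
# (sample_rate, np.ndarray) tuple, e.g. one second of silent int16 PCM:
#   audio = (44_100, np.zeros(44_100, dtype=np.int16))
#   run_sample_inference(audio, "someuser/wav2vec2-ipa", "wav2vec2", "ipa")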