# This module handles the task queue

import os
import multiprocessing
from typing import TypedDict
from datetime import datetime

import librosa
import numpy as np

from metrics import per, fer
from datasets import load_from_disk
from hf import get_repo_info, add_leaderboard_entry
from inference import clear_cache, load_model, transcribe
from codes import convert

# Lock handed to every evaluation process (see start_eval_task) and held
# around add_leaderboard_entry so only one process writes results at a time.
leaderboard_lock = multiprocessing.Lock()


class Task(TypedDict):
    """Record describing one submitted evaluation task and its current state."""

    # Lifecycle state; this module sets "submitted", "evaluating",
    # "completed" or "failed".
    status: str
    # Name forwarded to add_leaderboard_entry for the leaderboard row.
    display_name: str
    # Repository identifier; passed to get_repo_info and load_model.
    repo_id: str
    # Hash returned by get_repo_info; also serves as the task id returned
    # to the submitter and accepted by get_status.
    repo_hash: str
    # Last-modified value returned by get_repo_info.
    repo_last_modified: datetime
    # datetime.now() captured when the task was submitted (naive timestamp).
    submission_timestamp: datetime
    # Passed to load_model and transcribe to select how the model is run.
    model_type: str
    # Phone code of the model's output; anything other than "ipa" is
    # converted to IPA before scoring.
    phone_code: str
    # Size value returned by get_repo_info (may be None).
    model_bytes: int | None
    # URL forwarded to add_leaderboard_entry.
    url: str
    # str() of the exception when evaluation failed, otherwise None.
    error: str | None


# Every task submitted since this module was loaded, oldest first.
# start_eval_task appends manager dict proxies here (not plain dicts), which
# is how status updates written by the evaluation process become visible to
# get_status.
tasks: list[Task] = []


def get_status(query: str) -> dict:
    """Look up the most recent evaluation task matching a repo id or repo hash.

    The query is stripped and compared case-insensitively. Returns a plain
    dict copy of the matching task, or a dict with a single "error" key when
    the query is blank or nothing matches.
    """

    needle = query.strip().lower()
    if needle == "":
        return {"error": "Please enter a model id or task id"}

    # Scan newest submissions first so a resubmitted repo reports its latest task.
    found = next(
        (
            entry
            for entry in reversed(tasks)
            if entry["repo_id"].lower() == needle
            or entry["repo_hash"].lower() == needle
        ),
        None,
    )
    if found is None:
        return {"error": f"No results found for '{needle}'"}

    return dict(found)


# Manager shared by all tasks; created on first use by _get_manager.
_manager = None


def _get_manager():
    """Return the shared multiprocessing manager, starting it on first use."""
    global _manager
    if _manager is None:
        # Every multiprocessing.Manager() call starts its own server process,
        # so reuse a single one instead of starting one per submission.
        # A concurrent first call could start a second manager; that is
        # harmless because each proxy keeps a reference to its own manager.
        _manager = multiprocessing.Manager()
    return _manager


def start_eval_task(
    display_name: str, repo_id: str, url: str, model_type: str, phone_code: str
) -> str:
    """Start an evaluation task in a background process.

    Args:
        display_name: Name forwarded to the leaderboard entry.
        repo_id: Repository identifier of the model to evaluate.
        url: URL forwarded to the leaderboard entry.
        model_type: Model type passed to load_model / transcribe.
        phone_code: Phone code of the model's output ("ipa" means no conversion).

    Returns:
        The repo hash reported by get_repo_info; it doubles as the task id
        that get_status accepts.

    Any exception raised by get_repo_info propagates to the caller; no task
    is recorded in that case.
    """

    repo_hash, last_modified, size_bytes = get_repo_info(repo_id)
    # TODO: check if hash is different from the most recent submission if any for repo_id, otherwise don't recompute
    task = Task(
        status="submitted",
        display_name=display_name,
        repo_id=repo_id,
        repo_hash=repo_hash,
        repo_last_modified=last_modified,
        submission_timestamp=datetime.now(),
        model_type=model_type,
        phone_code=phone_code,
        model_bytes=size_bytes,
        url=url,
        error=None,
    )

    # Share the task through a manager dict so status changes made by the
    # worker process are visible to get_status in this process.
    task_proxy = _get_manager().dict(task)
    tasks.append(task_proxy)  # type: ignore
    multiprocessing.Process(
        target=_eval_task, args=(task_proxy, leaderboard_lock)
    ).start()

    return repo_hash


# Evaluation dataset, loaded once at import time from the "data/test"
# directory located next to this file.
test_ds = load_from_disk(os.path.join(os.path.dirname(__file__), "data", "test"))


def _eval_task(task: Task, leaderboard_lock):
    """Background task to evaluate a model and save the results.

    Transcribes every row of test_ds with the task's model, scores each
    transcript against the row's "ipa" reference with per() and fer(), and
    records the averages (overall and FER per "dataset" value) through
    add_leaderboard_entry.

    Args:
        task: Shared task record; its "status" (and "error" on failure) is
            updated in place so the submitting process can observe progress.
        leaderboard_lock: Lock held while add_leaderboard_entry runs.

    Any exception is caught and reported through the task record
    (status "failed", error set to str(exception)) rather than raised.
    """
    try:
        # Indicate task is processing
        task["status"] = "evaluating"

        # Read loop-invariant fields once; when `task` is a manager proxy
        # every lookup is a round trip to the manager process.
        model_type = task["model_type"]
        phone_code = task["phone_code"]

        # Evaluate model
        total_per = 0
        total_fer = 0
        fer_sums = {}
        row_counts = {}

        clear_cache()
        model = load_model(task["repo_id"], model_type)
        for row in test_ds:
            transcript = transcribe(row["audio"]["array"], model_type, model)  # type: ignore
            if phone_code != "ipa":
                transcript = convert(transcript, phone_code, "ipa")
            reference = row["ipa"]  # type: ignore
            row_per = per(transcript, reference)
            row_fer = fer(transcript, reference)
            total_per += row_per
            total_fer += row_fer

            # Track sum and count per dataset in the same pass, instead of
            # re-filtering the whole dataset once per dataset afterwards.
            dataset_name = row["dataset"]  # type: ignore
            fer_sums[dataset_name] = fer_sums.get(dataset_name, 0) + row_fer
            row_counts[dataset_name] = row_counts.get(dataset_name, 0) + 1

        per_dataset_fers = {
            name: fer_sums[name] / row_counts[name] for name in fer_sums
        }
        # An empty test set raises ZeroDivisionError here, which the handler
        # below reports as a failed task.
        num_rows = len(test_ds)
        average_per = total_per / num_rows
        average_fer = total_fer / num_rows

        # Save results
        with leaderboard_lock:
            add_leaderboard_entry(
                display_name=task["display_name"],
                repo_id=task["repo_id"],
                repo_hash=task["repo_hash"],
                repo_last_modified=task["repo_last_modified"],
                submission_timestamp=task["submission_timestamp"],
                average_per=average_per,
                average_fer=average_fer,
                url=task["url"],
                model_bytes=task["model_bytes"],
                per_dataset_fers=per_dataset_fers,
            )

        # Mark task as complete
        task["status"] = "completed"
    except Exception as e:
        # Top-level boundary of the worker process: surface the failure via
        # the shared task record since nobody can catch it from here.
        task["status"] = "failed"
        task["error"] = str(e)


def run_sample_inference(audio, model_id: str, model_type: str, phone_code: str):
    """Transcribe one audio clip with the given model and return the transcript.

    Args:
        audio: ``(sample_rate, samples)`` pair where ``samples`` is a NumPy
            array, either 1-D or 2-D with channels along axis 1
            (presumably the UI framework's NumPy audio format — confirm
            against the caller).
        model_id: Identifier passed to load_model.
        model_type: Model type passed to load_model and transcribe.
        phone_code: Phone code of the model's output; anything other than
            "ipa" is converted to IPA with convert().

    Returns:
        The transcript produced by transcribe(), converted to IPA if needed.
    """
    clear_cache()
    try:
        # Load model
        model = load_model(model_id, model_type)

        # Format audio as monochannel 16 kHz float32
        sample_rate, wav_array = audio
        # NOTE(review): integer PCM input is cast but not rescaled to
        # [-1, 1] — confirm this matches what the models expect.
        wav_array = wav_array.astype(np.float32)
        if wav_array.ndim == 2:
            # Downmix any channel count (not only stereo); otherwise a 2-D
            # array would be resampled along its channel axis below.
            wav_array = np.mean(wav_array, axis=1)
        wav_array = librosa.resample(y=wav_array, orig_sr=sample_rate, target_sr=16_000)

        # Transcribe
        transcript = transcribe(wav_array, model_type, model)
        if phone_code != "ipa":
            transcript = convert(transcript, phone_code, "ipa")
        return transcript
    finally:
        # Clear the cache even when loading or transcription fails.
        clear_cache()