import os
import json
import time
import requests
import schedule
import threading
import gradio as gr
from tqdm import tqdm
from zoneinfo import ZoneInfo
from datetime import datetime
from huggingface_hub import HfApi
from tzlocal import get_localzone

# Seconds to pause between consecutive outbound requests and scheduler ticks.
DELAY, TIMEOUT = 1, 15  # TIMEOUT: per-request timeout (seconds) for every HTTP call.

# Base URL used for both the ModelScope web pages and its /api/v1 endpoints.
DOMAIN = "https://www.modelscope.cn"

# Browser-like User-Agent sent with every request made by this script.
_USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537"
HEADER = {"User-Agent": _USER_AGENT}


def fix_datetime(naive_time: datetime, target_tz=ZoneInfo("Asia/Shanghai")):
    """Format a datetime as ``YYYY-MM-DD HH:MM:SS`` in *target_tz*.

    A naive datetime (no ``tzinfo``) is interpreted in the host machine's
    local time zone. An already time-zone-aware datetime is converted as-is
    rather than having its zone overwritten.

    Args:
        naive_time: The datetime to render, or ``None``.
        target_tz: Zone in which to display the time (Asia/Shanghai by default).

    Returns:
        The formatted string, or ``None`` when *naive_time* is ``None``.
    """
    if naive_time is None:
        return None

    if naive_time.tzinfo is None:
        # Only attach the local zone to naive values; replacing tzinfo on an
        # aware datetime would silently shift the instant it represents.
        naive_time = naive_time.replace(tzinfo=get_localzone())

    return naive_time.astimezone(target_tz).strftime("%Y-%m-%d %H:%M:%S")


def get_spaces(username: str):
    """Return the public URLs of every Hugging Face Space owned by *username*.

    A Space id ``owner/name`` is turned into its subdomain by replacing ``/``
    and ``_`` with ``-`` and lower-casing. Spaces built with the ``static``
    SDK are served from ``<subdomain>.static.hf.space``; every other SDK
    (gradio, streamlit, docker, ...) is served from ``<subdomain>.hf.space``.

    Args:
        username: Hugging Face user or organisation name.

    Returns:
        List of Space URLs; an empty list if listing fails (the error is printed).
    """
    try:
        urls = []
        for space in HfApi().list_spaces(author=username):
            subdomain = space.id.replace("/", "-").replace("_", "-").lower()
            # Only static Spaces use the dedicated static host.
            host = "static.hf.space" if space.sdk == "static" else "hf.space"
            urls.append(f"https://{subdomain}.{host}")

        return urls

    except Exception as e:
        print(f"An error occurred in the request: {e}")

    return []


def activate_space(url: str):
    """Hit *url* once with a GET so the Space wakes up.

    Any failure (network error, timeout, non-2xx status) is printed and
    swallowed; nothing is returned.
    """
    try:
        requests.get(url, headers=HEADER, timeout=TIMEOUT).raise_for_status()
    except Exception as err:
        print(err)


def get_studios(username: str, max_retries: int = 3):
    """Return the ModelScope studios of *username* whose status is ``"Expired"``.

    Args:
        username: ModelScope account whose studios are inspected.
        max_retries: Extra attempts made after a request timeout; each attempt
            restarts the whole collection. Bounded so that a persistently
            unreachable API cannot recurse/retry forever.

    Returns:
        A list of ``"username/studio_name"`` ids. It is empty if the listing
        fails or every retry times out; errors are printed.
    """
    for attempt in range(max_retries + 1):
        try:
            return _list_expired_studios(username)

        except requests.exceptions.Timeout as e:
            if attempt == max_retries:
                print(f"Timeout: {e}, giving up.")
                break
            print(f"Timeout: {e}, retrying...")
            time.sleep(DELAY)

        except Exception as e:
            print(f"Requesting error: {e}")
            break

    return []


def _list_expired_studios(username: str):
    """Fetch *username*'s studio list and keep the ids whose status is Expired.

    Requests the first page (up to 1000 studios, most recently modified
    first). Timeouts are re-raised so the caller can retry. Any other error
    while checking a single studio's status is printed and only that studio
    is skipped, so one bad entry does not discard the rest.
    """
    response = requests.put(
        f"{DOMAIN}/api/v1/studios/{username}/list",
        data=json.dumps(
            {
                "PageNumber": 1,
                "PageSize": 1000,
                "Name": "",
                "SortBy": "gmt_modified",
                "Order": "desc",
            }
        ),
        headers=HEADER,
        timeout=TIMEOUT,
    )
    response.raise_for_status()

    studios = []
    for space in response.json()["Data"]["Studios"] or []:
        repo = f"{username}/{space['Name']}"
        try:
            status = requests.get(
                f"{DOMAIN}/api/v1/studio/{repo}/status",
                headers=HEADER,
                timeout=TIMEOUT,
            ).json()["Data"]["Status"]
        except requests.exceptions.Timeout:
            raise  # let get_studios decide whether to retry
        except Exception as e:
            print(f"Failed to query status of {repo}: {e}")
            continue

        if status == "Expired":
            studios.append(repo)

    return studios


def activate_studio(repo: str, holding_delay=5, max_retries: int = 3):
    """Restart an expired ModelScope studio and wait until it reports Running.

    Calls the studio's ``start_expired`` endpoint. It then polls the status
    endpoint, loading the studio page between polls, until the reported
    status is ``"Running"``.

    Args:
        repo: Studio id in ``"owner/name"`` form.
        holding_delay: Seconds to sleep between status polls. This value is
            kept across timeout retries.
        max_retries: Extra attempts after a request timeout. Each attempt
            restarts from the ``start_expired`` call.

    Timeouts past the retry budget and all other errors are printed; nothing
    is raised.
    """
    repo_page = f"{DOMAIN}/studios/{repo}"
    status_api = f"{DOMAIN}/api/v1/studio/{repo}/status"
    start_expired_api = f"{DOMAIN}/api/v1/studio/{repo}/start_expired"

    for attempt in range(max_retries + 1):
        try:
            response = requests.put(start_expired_api, headers=HEADER, timeout=TIMEOUT)
            response.raise_for_status()
            # NOTE(review): this poll has no upper bound; a studio that never
            # reaches "Running" keeps this (daemon) thread alive indefinitely.
            while (
                requests.get(status_api, headers=HEADER, timeout=TIMEOUT).json()["Data"][
                    "Status"
                ]
                != "Running"
            ):
                # Presumably visiting the page helps nudge the studio awake — confirm.
                requests.get(repo_page, headers=HEADER, timeout=TIMEOUT)
                time.sleep(holding_delay)
            return

        except requests.exceptions.Timeout as e:
            if attempt == max_retries:
                print(f"Failed to activate {repo}: {e}, giving up.")
                return
            print(f"Failed to activate {repo}: {e}, retrying...")

        except Exception as e:
            print(e)
            return


def activate(users=None):
    """Wake every Hugging Face Space and expired ModelScope studio of *users*.

    Args:
        users: Semicolon-separated account names, e.g. ``"alice; bob"``.
            Surrounding whitespace and empty entries are ignored. When
            omitted, the ``users`` environment variable is read at call time,
            so scheduled runs pick up its current value. An unset variable
            means no accounts.

    ModelScope studios are restarted in background daemon threads. The final
    summary is therefore printed before those restarts necessarily finish.
    """
    if users is None:
        users = os.getenv("users", "")
    usernames = [name.strip() for name in users.split(";") if name.strip()]

    spaces = []
    for username in tqdm(usernames, desc="Collecting spaces"):
        spaces += get_spaces(username)
        time.sleep(DELAY)

    for space in tqdm(spaces, desc="Activating spaces"):
        activate_space(space)
        time.sleep(DELAY)

    studios = []
    for username in tqdm(usernames, desc="Collecting studios"):
        studios += get_studios(username)
        time.sleep(DELAY)

    for studio in tqdm(studios, desc="Activating studios"):
        threading.Thread(target=activate_studio, args=(studio,), daemon=True).start()
        time.sleep(DELAY)

    print("\n".join(spaces + studios) + "\nActivation complete!")


def run_schedule():
    """Loop forever, running any due ``schedule`` jobs every ``DELAY`` seconds."""
    run_due_jobs = schedule.run_pending
    while True:
        run_due_jobs()
        time.sleep(DELAY)


def monitor(period=None):
    """Run :func:`activate` now, then every *period* hours on a background thread.

    Args:
        period: Interval in hours, as an int or numeric string. When omitted,
            the ``period`` environment variable is read at call time.

    Raises:
        ValueError: If no period is configured or it is not a positive
            integer. This is checked before any activation work starts.
    """
    if period is None:
        period = os.getenv("period")
    try:
        hours = int(period)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"'period' must be a positive integer number of hours, got {period!r}"
        ) from err
    if hours <= 0:
        raise ValueError(f"'period' must be a positive integer number of hours, got {period!r}")

    activate()
    print(f"Monitor is on and triggered every {hours}h...")
    schedule.every(hours).hours.do(activate)
    threading.Thread(target=run_schedule, daemon=True).start()


def tasklist():
    """Describe every scheduled job for the Gradio status page.

    Returns:
        One line per job, giving its interval and its last and next run times
        formatted by ``fix_datetime``. A job that has not run yet shows
        "never". The string ``"None"`` is returned when nothing is scheduled.
    """
    lines = []
    for job in schedule.get_jobs():
        last_run = fix_datetime(job.last_run) or "never"
        next_run = fix_datetime(job.next_run)
        # "h" assumes hour-based jobs, which is what monitor() schedules.
        lines.append(
            f"Every {job.interval}h do (last run: {last_run}, next run: {next_run})"
        )

    return "\n".join(lines) if lines else "None"


if __name__ == "__main__":
    # Start the initial activation and the recurring schedule first...
    monitor()

    # ...then serve a minimal page that reports the scheduled task status.
    status_page = gr.Interface(
        fn=tasklist,
        inputs=None,
        outputs=gr.Textbox(label="Current task details"),
        title="See current task status",
        flagging_mode="never",
    )
    status_page.launch()