Maharshi Gor
Squash merge dictify-states into main
9756440
raw
history blame
5.01 kB
import datasets
import gradio as gr
from apscheduler.schedulers.background import BackgroundScheduler
from huggingface_hub import snapshot_download
from loguru import logger
import populate
from about import LEADERBOARD_INTRODUCTION_TEXT, LEADERBOARD_TITLE
from app_configs import DEFAULT_SELECTIONS, THEME
from components.quizbowl.bonus import BonusInterface
from components.quizbowl.tossup import TossupInterface
from display.css_html_js import fonts_header, js_head, leaderboard_css
from display.custom_css import css_bonus, css_pipeline, css_tossup
from display.guide import BUILDING_MARKDOWN, GUIDE_MARKDOWN, QUICKSTART_MARKDOWN
from display.utils import AutoEvalColumn, fields
# Constants
from envs import (
API,
EVAL_REQUESTS_PATH,
EVAL_RESULTS_PATH,
LEADERBOARD_REFRESH_INTERVAL,
PLAYGROUND_DATASET_NAMES,
QUEUE_REPO,
REPO_ID,
RESULTS_REPO,
SERVER_REFRESH_INTERVAL,
)
from workflows import factory
from workflows.configs import AVAILABLE_MODELS
def restart_space():
    """Restart the Space identified by ``REPO_ID`` through ``API.restart_space``."""
    target_repo = REPO_ID
    API.restart_space(repo_id=target_repo)
def download_dataset_snapshot(repo_id, local_dir):
    """Download a snapshot of a Hub *dataset* repository into ``local_dir``.

    On any exception the error is logged together with its traceback and
    ``restart_space()`` is called. The exception is not re-raised: if
    ``restart_space()`` returns, this function returns ``None`` and the caller
    carries on with whatever is already in ``local_dir``.

    Args:
        repo_id: Hub id of the dataset repository to download.
        local_dir: Local directory the snapshot is written into.
    """
    logger.info(f"Downloading dataset snapshot from {repo_id} to {local_dir}")
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="dataset",
            tqdm_class=None,
        )
    except Exception as e:
        # logger.exception (unlike logger.error) attaches the traceback. str(e)
        # alone is often empty or too terse to diagnose network/auth failures.
        logger.exception(f"Error downloading dataset snapshot from {repo_id} to {local_dir}: {e}. Restarting space.")
        restart_space()
# Runs at import time (not only under __main__): mirror the evaluation-request
# queue dataset into EVAL_REQUESTS_PATH before the app is built.
download_dataset_snapshot(repo_id=QUEUE_REPO, local_dir=EVAL_REQUESTS_PATH)
def fetch_leaderboard_df():
    """Refresh the local results snapshot and build the leaderboard table.

    This is used as the value callable of the leaderboard ``gr.Dataframe``
    and as the Refresh button handler.

    Returns:
        Whatever ``populate.get_leaderboard_df`` builds from ``EVAL_RESULTS_PATH``.
    """
    download_dataset_snapshot(RESULTS_REPO, EVAL_RESULTS_PATH)
    leaderboard_df = populate.get_leaderboard_df(EVAL_RESULTS_PATH)
    # Log only once the table has actually been built. Previously this was
    # emitted before any work happened, even when the fetch went on to fail.
    logger.info("Leaderboard fetched...")
    return leaderboard_df
def _is_playground_sample(example) -> bool:
    """Keep rows whose ``qid`` field 2 is ``"1"`` and field 3 is an int <= 10.

    The qid is split on ``"-"`` once. As before, a malformed qid raises
    (``IndexError``/``ValueError``) rather than being silently dropped.
    """
    parts = example["qid"].split("-")
    return parts[2] == "1" and int(parts[3]) <= 10


def load_dataset(mode: str):
    """Load the filtered ``"eval"`` split of the playground dataset for ``mode``.

    Args:
        mode: Either ``"tossup"`` or ``"bonus"``. It selects the dataset name
            from ``PLAYGROUND_DATASET_NAMES``.

    Returns:
        The ``"eval"`` split, filtered with ``_is_playground_sample``.

    Raises:
        ValueError: If ``mode`` is not ``"tossup"`` or ``"bonus"``.
    """
    if mode not in ("tossup", "bonus"):
        raise ValueError(f"Invalid mode: {mode}")
    ds = datasets.load_dataset(PLAYGROUND_DATASET_NAMES[mode], split="eval")
    return ds.filter(_is_playground_sample)
def get_default_tab_id(request: gr.Request):
    """Select the top-level tab named by the ``?tab=`` query parameter.

    Falls back to ``"tossup"`` when the parameter is absent. Returns a
    ``gr.update`` that sets ``selected`` on the target component.
    """
    logger.info(f"Request: {request}")
    params = request.query_params
    selected_tab = params.get("tab", "tossup")
    return gr.update(selected=selected_tab)
if __name__ == "__main__":
    # Schedule restart_space() to run every SERVER_REFRESH_INTERVAL seconds
    # in the background while the app is serving.
    scheduler = BackgroundScheduler()
    scheduler.add_job(restart_space, "interval", seconds=SERVER_REFRESH_INTERVAL)
    scheduler.start()

    # Page styling and <head> additions, combined from the display modules.
    css = css_pipeline + css_tossup + css_bonus + leaderboard_css
    head = fonts_header + js_head

    # Load both playground datasets once at startup. Each is passed to its tab's interface below.
    tossup_ds = load_dataset("tossup")
    bonus_ds = load_dataset("bonus")

    # NOTE: Gradio lays out components in the order they are created inside
    # these `with` blocks, so statement order here defines the page layout.
    with gr.Blocks(
        css=css,
        head=head,
        theme=THEME,
        title="Quizbowl Bot",
    ) as demo:
        with gr.Row():
            gr.Markdown("## Welcome to Quizbowl Bot! This is a tool for creating and testing quizbowl agents.")
        # NOTE(review): `gtab` is bound but never referenced in this block;
        # get_default_tab_id() looks intended to drive it (e.g. via demo.load) -- confirm.
        with gr.Tabs() as gtab:
            # NOTE(review): the emoji in the tab/button labels below look like
            # mis-decoded UTF-8 (mojibake); verify the file's encoding.
            with gr.Tab("πŸ›ŽοΈ Tossup Agents", id="tossup"):
                # `|` builds a new mapping, so DEFAULT_SELECTIONS itself is not
                # modified (assumes its values are dicts).
                defaults = DEFAULT_SELECTIONS["tossup"] | {
                    "init_workflow": factory.create_simple_qb_tossup_workflow(),
                }
                tossup_interface = TossupInterface(demo, tossup_ds, AVAILABLE_MODELS, defaults)
            with gr.Tab("πŸ™‹πŸ»β€β™‚οΈ Bonus Round Agents", id="bonus"):
                # `defaults` is rebound here; the tossup interface keeps the dict it was given.
                defaults = DEFAULT_SELECTIONS["bonus"] | {
                    "init_workflow": factory.create_simple_qb_bonus_workflow(),
                }
                bonus_interface = BonusInterface(demo, bonus_ds, AVAILABLE_MODELS, defaults)
            with gr.Tab("πŸ… Leaderboard", elem_id="llm-benchmark-tab-table", id="leaderboard"):
                # Timer ticking every LEADERBOARD_REFRESH_INTERVAL; used as the table's `every=` trigger.
                leaderboard_timer = gr.Timer(LEADERBOARD_REFRESH_INTERVAL)
                gr.Markdown("<a id='leaderboard' href='#leaderboard'>QANTA Leaderboard</a>")
                gr.Markdown(LEADERBOARD_INTRODUCTION_TEXT)
                refresh_btn = gr.Button("πŸ”„ Refresh")
                # `value` is a callable, so Gradio calls fetch_leaderboard_df to fill
                # the table and re-runs it on each tick of leaderboard_timer.
                leaderboard_table = gr.Dataframe(
                    value=fetch_leaderboard_df,
                    every=leaderboard_timer,
                    headers=[c.name for c in fields(AutoEvalColumn)],
                    datatype=[c.type for c in fields(AutoEvalColumn)],
                    elem_id="leaderboard-table",
                    interactive=False,
                    visible=True,
                )
                # Manual refresh: the same fetch, with the result written into the table.
                refresh_btn.click(fn=fetch_leaderboard_df, inputs=[], outputs=leaderboard_table)
            with gr.Tab("❓ Help", id="help"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown(QUICKSTART_MARKDOWN)
                    with gr.Column():
                        gr.Markdown(BUILDING_MARKDOWN)
    # Enable queuing with a default concurrency limit of 40 for event listeners
    # that don't set their own, then start the server.
    demo.queue(default_concurrency_limit=40).launch()