Maharshi Gor
Enhance model provider detection, add a repository management script, and add support for multi-step agents.
973519b
import datasets | |
import gradio as gr | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from huggingface_hub import snapshot_download | |
from components.quizbowl.bonus import BonusInterface | |
from components.quizbowl.tossup import TossupInterface | |
from display.custom_css import css_pipeline, css_tossup | |
# Constants | |
from src.envs import ( | |
API, | |
AVAILABLE_MODELS, | |
DEFAULT_SELECTIONS, | |
EVAL_REQUESTS_PATH, | |
EVAL_RESULTS_PATH, | |
PLAYGROUND_DATASET_NAMES, | |
QUEUE_REPO, | |
REPO_ID, | |
RESULTS_REPO, | |
THEME, | |
TOKEN, | |
) | |
from workflows import factory | |
def restart_space():
    """Restart this Space by calling ``API.restart_space`` for ``REPO_ID``.

    Used both as a recovery step when the startup dataset download fails and
    as the periodic job registered in the ``__main__`` guard.
    """
    target_repo = REPO_ID
    API.restart_space(repo_id=target_repo)
def _sync_dataset_snapshot(repo_id, local_dir) -> None:
    """Mirror Hub dataset ``repo_id`` into ``local_dir``; restart the Space on failure.

    Args:
        repo_id: Dataset repository to download with ``snapshot_download``.
        local_dir: Local directory the snapshot is written into.
    """
    import traceback  # local import: only needed on the failure path

    print(f"Syncing dataset {repo_id} -> {local_dir}")
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            repo_type="dataset",
            tqdm_class=None,
            etag_timeout=30,
            token=TOKEN,
        )
    except Exception:
        # Startup boundary: keep the restart-to-recover behaviour, but surface
        # the cause instead of silently swallowing it.
        traceback.print_exc()
        restart_space()


### Space initialisation
_sync_dataset_snapshot(QUEUE_REPO, EVAL_REQUESTS_PATH)
_sync_dataset_snapshot(RESULTS_REPO, EVAL_RESULTS_PATH)
js_preamble = """ | |
<link href="https://fonts.cdnfonts.com/css/roboto-mono" rel="stylesheet"> | |
<script> | |
const gradioApp = document.getElementsByTagName('gradio-app')[0]; | |
console.log("Gradio app:", gradioApp); | |
console.log(gradioApp.querySelectorAll('.token')); | |
console.log(document.querySelectorAll('.token')); | |
// Function to trigger Python callback | |
const setHiddenIndex = (index) => { | |
console.log("Setting hidden index to:", index); | |
const hiddenIndex = gradioApp.querySelector("#hidden-index textarea"); | |
if (hiddenIndex) { | |
hiddenIndex.value = index; | |
let event = new Event("input", { bubbles: true}); | |
Object.defineProperty(event, "target", { value: hiddenIndex}); | |
hiddenIndex.dispatchEvent(event); | |
} | |
}; | |
// Add event listeners to all tokens | |
function setupTokenListeners() { | |
const tokens = gradioApp.querySelectorAll('.token'); | |
console.log("Tokens:", tokens); | |
tokens.forEach(token => { | |
token.addEventListener('mouseover', function() { | |
const index = parseInt(this.getAttribute('data-index')); | |
console.log("Mouseover token index:", index); | |
// Reset all tokens | |
gradioApp.querySelectorAll('.token').forEach(el => { | |
el.classList.remove('highlighted'); | |
}); | |
// Highlight this token | |
this.classList.add('highlighted'); | |
// Update the hidden index to trigger the Python callback | |
setHiddenIndex(index); | |
}); | |
}); | |
} | |
console.log("Preamble complete"); | |
document.addEventListener("DOMContentLoaded", function() { | |
// Setup initial listeners | |
console.log("DOM fully loaded and parsed"); | |
setupTokenListeners(); | |
// Setup a mutation observer to handle dynamically added tokens | |
const observer = new MutationObserver(function(mutations) { | |
mutations.forEach(function(mutation) { | |
if (mutation.addedNodes.length) { | |
setupTokenListeners(); | |
} | |
}); | |
}); | |
// Start observing the token container for changes | |
const tokenContainer = gradioApp.querySelector('.token-container'); | |
console.log("Token container:", tokenContainer); | |
if (tokenContainer) { | |
observer.observe(tokenContainer.parentNode, { childList: true, subtree: true }); | |
} | |
console.log("Listener setup complete"); | |
}); | |
</script> | |
""" | |
def load_dataset(mode: str, *, packet: str = "1", max_question: int = 10):
    """Load the playground ``eval`` split for ``mode`` and keep a small subset.

    Args:
        mode: ``"tossup"`` or ``"bonus"``; selects the dataset name from
            ``PLAYGROUND_DATASET_NAMES``.
        packet: Required value of the third ``"-"``-separated field of ``qid``.
        max_question: Largest allowed integer value of the fourth ``qid`` field.

    Returns:
        The filtered ``eval`` split.

    Raises:
        ValueError: If ``mode`` is not ``"tossup"`` or ``"bonus"``.
    """
    if mode not in ("tossup", "bonus"):
        raise ValueError(f"Invalid mode: {mode}")

    def _keep(example) -> bool:
        # Split once per example; assumes qids have at least four
        # "-"-separated fields (a shorter qid raises IndexError, as before).
        fields = example["qid"].split("-")
        return fields[2] == packet and int(fields[3]) <= max_question

    ds = datasets.load_dataset(PLAYGROUND_DATASET_NAMES[mode], split="eval")
    return ds.filter(_keep)
def main():
    """Assemble the "Quizbowl Bot" Gradio app (tossup and bonus tabs) and launch it.

    Both playground datasets are loaded first. Each tab's defaults are its
    ``DEFAULT_SELECTIONS`` entry extended with an initial workflow and the
    ``simple_workflow`` flag, and are passed to that tab's interface.
    """
    tossup_dataset = load_dataset("tossup")
    bonus_dataset = load_dataset("bonus")

    demo = gr.Blocks(
        title="Quizbowl Bot",
        theme=THEME,
        head=js_preamble,
        css=css_pipeline + css_tossup,
    )

    # Entering both contexts in one statement is equivalent to nesting them.
    with demo, gr.Tabs():
        with gr.Tab("Tossup Agents"):
            tossup_defaults = DEFAULT_SELECTIONS["tossup"] | {
                "init_workflow": factory.create_quizbowl_simple_workflow(),
                "simple_workflow": False,
            }
            # Bound to a name although unused afterwards; presumably keeps the
            # interface object referenced while the app runs.
            tossup_ui = TossupInterface(demo, tossup_dataset, AVAILABLE_MODELS, tossup_defaults)

        with gr.Tab("Bonus Round Agents"):
            bonus_defaults = DEFAULT_SELECTIONS["bonus"] | {
                "init_workflow": factory.create_quizbowl_bonus_simple_workflow(),
                "simple_workflow": True,
            }
            bonus_ui = BonusInterface(demo, bonus_dataset, AVAILABLE_MODELS, bonus_defaults)

    demo.queue(default_concurrency_limit=40).launch()
if __name__ == "__main__": | |
scheduler = BackgroundScheduler() | |
scheduler.add_job(restart_space, "interval", seconds=1800) | |
scheduler.start() | |
main() | |