Commit 973519b by Maharshi Gor
Parent(s): 193db9d

Enhance model provider detection and add repository management script. Add support for multi-step agents.
Files changed:
- app.py (+51, -2)
- check_repos.py (+27, -0)
- src/components/quizbowl/bonus.py (+7, -3)
- src/components/quizbowl/tossup.py (+9, -1)
- src/llms.py (+129, -0)
- src/utils.py (+3, -1)
- src/workflows/qb/multi_step_agent.py (+192, -0)
- src/workflows/qb/simple_agent.py (+0, -8)
app.py CHANGED
@@ -1,14 +1,60 @@
 import datasets
 import gradio as gr
+from apscheduler.schedulers.background import BackgroundScheduler
+from huggingface_hub import snapshot_download
 
 from components.quizbowl.bonus import BonusInterface
 from components.quizbowl.tossup import TossupInterface
 from display.custom_css import css_pipeline, css_tossup
 
 # Constants
-from envs import
+from src.envs import (
+    API,
+    AVAILABLE_MODELS,
+    DEFAULT_SELECTIONS,
+    EVAL_REQUESTS_PATH,
+    EVAL_RESULTS_PATH,
+    PLAYGROUND_DATASET_NAMES,
+    QUEUE_REPO,
+    REPO_ID,
+    RESULTS_REPO,
+    THEME,
+    TOKEN,
+)
 from workflows import factory
 
+
+def restart_space():
+    API.restart_space(repo_id=REPO_ID)
+
+
+### Space initialisation
+try:
+    print(EVAL_REQUESTS_PATH)
+    snapshot_download(
+        repo_id=QUEUE_REPO,
+        local_dir=EVAL_REQUESTS_PATH,
+        repo_type="dataset",
+        tqdm_class=None,
+        etag_timeout=30,
+        token=TOKEN,
+    )
+except Exception:
+    restart_space()
+try:
+    print(EVAL_RESULTS_PATH)
+    snapshot_download(
+        repo_id=RESULTS_REPO,
+        local_dir=EVAL_RESULTS_PATH,
+        repo_type="dataset",
+        tqdm_class=None,
+        etag_timeout=30,
+        token=TOKEN,
+    )
+except Exception:
+    restart_space()
+
+
 js_preamble = """
 <link href="https://fonts.cdnfonts.com/css/roboto-mono" rel="stylesheet">
 
@@ -118,8 +164,11 @@ def main():
     }
     bonus_interface = BonusInterface(app, bonus_ds, AVAILABLE_MODELS, defaults)
 
-    app.queue(
+    app.queue(default_concurrency_limit=40).launch()
 
 
 if __name__ == "__main__":
+    scheduler = BackgroundScheduler()
+    scheduler.add_job(restart_space, "interval", seconds=1800)
+    scheduler.start()
     main()
check_repos.py ADDED
@@ -0,0 +1,27 @@
+from huggingface_hub import HfApi
+
+from src.envs import QUEUE_REPO, RESULTS_REPO, TOKEN
+
+
+def check_and_create_repos():
+    api = HfApi(token=TOKEN)
+
+    # Check and create queue repo
+    try:
+        api.repo_info(repo_id=QUEUE_REPO, repo_type="dataset")
+        print(f"Queue repository {QUEUE_REPO} exists")
+    except Exception:
+        print(f"Creating queue repository {QUEUE_REPO}")
+        api.create_repo(repo_id=QUEUE_REPO, repo_type="dataset", exist_ok=True, private=False)
+
+    # Check and create results repo
+    try:
+        api.repo_info(repo_id=RESULTS_REPO, repo_type="dataset")
+        print(f"Results repository {RESULTS_REPO} exists")
+    except Exception:
+        print(f"Creating results repository {RESULTS_REPO}")
+        api.create_repo(repo_id=RESULTS_REPO, repo_type="dataset", exist_ok=True, private=False)
+
+
+if __name__ == "__main__":
+    check_and_create_repos()
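The queue and results blocks above are identical except for the repository they target, and `exist_ok=True` already makes creation idempotent. A possible consolidation, sketched with a hypothetical ensure_repo helper that is not part of this commit:

from huggingface_hub import HfApi

from src.envs import QUEUE_REPO, RESULTS_REPO, TOKEN


def ensure_repo(api: HfApi, label: str, repo_id: str) -> None:
    # Create the dataset repo only if it does not already exist.
    try:
        api.repo_info(repo_id=repo_id, repo_type="dataset")
        print(f"{label} repository {repo_id} exists")
    except Exception:
        print(f"Creating {label} repository {repo_id}")
        api.create_repo(repo_id=repo_id, repo_type="dataset", exist_ok=True, private=False)


if __name__ == "__main__":
    api = HfApi(token=TOKEN)
    ensure_repo(api, "Queue", QUEUE_REPO)
    ensure_repo(api, "Results", RESULTS_REPO)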
src/components/quizbowl/bonus.py CHANGED
@@ -10,7 +10,7 @@ from datasets import Dataset
 
 from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
 from submission import submit
-from workflows import
+from workflows.qb.multi_step_agent import MultiStepBonusAgent
 from workflows.qb.simple_agent import SimpleBonusAgent
 from workflows.structs import ModelStep, Workflow
 
@@ -255,9 +255,13 @@ class BonusInterface:
         """Get the model outputs for a given question ID."""
         outputs = []
         leadin = example["leadin"]
+        workflow = pipeline_state.workflow
+        if len(workflow.steps) > 1:
+            agent = MultiStepBonusAgent(workflow)
+        else:
+            agent = SimpleBonusAgent(workflow)
 
         for i, part in enumerate(example["parts"]):
-            agent = SimpleBonusAgent(workflow=pipeline_state.workflow)
             # Run model for each part
             part_output = agent.run(leadin, part["part"])
 
@@ -384,7 +388,7 @@ class BonusInterface:
         )
 
         self.submit_btn.click(
-            fn=self.
+            fn=self.submit_model,
             inputs=[
                 self.model_name_input,
                 self.description_input,
src/components/quizbowl/tossup.py CHANGED
@@ -9,6 +9,7 @@ from datasets import Dataset
 
 from components.model_pipeline.model_pipeline import PipelineInterface, PipelineState
 from submission import submit
+from workflows.qb.multi_step_agent import MultiStepTossupAgent
 from workflows.qb.simple_agent import SimpleTossupAgent
 from workflows.structs import ModelStep, Workflow
 
@@ -21,6 +22,9 @@ from .plotting import (
     update_plot,
 )
 
+# TODO: Error handling on run tossup and evaluate tossup and show correct messages
+# TODO: ^^ Same for Bonus
+
 
 def add_model_scores(model_outputs: list[dict], clean_answers: list[str], run_indices: list[int]) -> list[dict]:
     """Add model scores to the model outputs."""
@@ -291,7 +295,11 @@ class TossupInterface:
         for run_idx in example["run_indices"]:
             question_runs.append(" ".join(tokens[: run_idx + 1]))
 
-        agent = SimpleTossupAgent(workflow=pipeline_state.workflow, buzz_threshold=buzz_threshold)
+        workflow = pipeline_state.workflow
+        if len(workflow.steps) > 1:
+            agent = MultiStepTossupAgent(workflow, buzz_threshold)
+        else:
+            agent = SimpleTossupAgent(workflow, buzz_threshold)
         outputs = list(agent.run(question_runs, early_stop=early_stop))
         outputs = add_model_scores(outputs, example["clean_answers"], example["run_indices"])
         return outputs
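The step-count dispatch above now appears in both bonus.py and tossup.py. One way it could be factored out, sketched as a hypothetical create_tossup_agent helper that is not part of this commit:

from workflows.qb.multi_step_agent import MultiStepTossupAgent
from workflows.qb.simple_agent import SimpleTossupAgent
from workflows.structs import Workflow


def create_tossup_agent(workflow: Workflow, buzz_threshold: float):
    # Multi-step workflows get the multi-step agent; single-step ones keep the simple agent.
    if len(workflow.steps) > 1:
        return MultiStepTossupAgent(workflow, buzz_threshold)
    return SimpleTossupAgent(workflow, buzz_threshold)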
src/llms.py ADDED
@@ -0,0 +1,129 @@
+# %%
+import json
+import os
+from typing import Optional
+
+import cohere
+import json_repair
+import numpy as np
+from anthropic import Anthropic
+from langchain_anthropic import ChatAnthropic
+from langchain_cohere import ChatCohere
+from langchain_openai import ChatOpenAI
+from openai import OpenAI
+from pydantic import BaseModel, Field
+from rich import print as rprint
+
+import utils
+from envs import AVAILABLE_MODELS
+
+
+class LLMOutput(BaseModel):
+    content: str = Field(description="The content of the response")
+    logprob: Optional[float] = Field(None, description="The log probability of the response")
+
+
+def completion(model: str, system: str, prompt: str, response_format, logprobs: bool = False) -> str:
+    """
+    Generate a completion from an LLM provider with structured output.
+
+    Args:
+        model (str): Provider and model name in format "provider/model" (e.g. "OpenAI/gpt-4")
+        system (str): System prompt/instructions for the model
+        prompt (str): User prompt/input
+        response_format: Pydantic model defining the expected response structure
+        logprobs (bool, optional): Whether to return log probabilities. Defaults to False.
+            Note: Not supported by Anthropic models.
+
+    Returns:
+        dict: Contains:
+            - output: The structured response matching response_format
+            - logprob: (optional) Sum of log probabilities if logprobs=True
+            - prob: (optional) Exponential of logprob if logprobs=True
+
+    Raises:
+        ValueError: If logprobs=True with Anthropic models
+    """
+    if model not in AVAILABLE_MODELS:
+        raise ValueError(f"Model {model} not supported")
+    model_name = AVAILABLE_MODELS[model]["model"]
+    provider = model.split("/")[0]
+    if provider == "Cohere":
+        return _cohere_completion(model_name, system, prompt, response_format, logprobs)
+    elif provider == "OpenAI":
+        return _openai_completion(model_name, system, prompt, response_format, logprobs)
+    elif provider == "Anthropic":
+        if logprobs:
+            raise ValueError("Anthropic does not support logprobs")
+        return _anthropic_completion(model_name, system, prompt, response_format)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
+
+
+def _cohere_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
+    messages = [
+        {"role": "system", "content": system},
+        {"role": "user", "content": prompt},
+    ]
+    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
+    response = client.chat(
+        model=model,
+        messages=messages,
+        response_format={"type": "json_schema", "json_schema": response_model.model_json_schema()},
+        logprobs=logprobs,
+    )
+    output = {}
+    output["content"] = response.message.content[0].text
+    output["output"] = response_model.model_validate_json(response.message.content[0].text).model_dump()
+    if logprobs:
+        output["logprob"] = sum(lp.logprobs[0] for lp in response.logprobs)
+        output["prob"] = np.exp(output["logprob"])
+    return output
+
+
+def _openai_completion(model: str, system: str, prompt: str, response_model, logprobs: bool = True) -> str:
+    messages = [
+        {"role": "system", "content": system},
+        {"role": "user", "content": prompt},
+    ]
+    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    response = client.beta.chat.completions.parse(
+        model=model,
+        messages=messages,
+        response_format=response_model,
+        logprobs=logprobs,
+    )
+    output = {}
+    output["content"] = response.choices[0].message.content
+    output["output"] = response.choices[0].message.parsed.model_dump()
+    if logprobs:
+        output["logprob"] = sum(lp.logprob for lp in response.choices[0].logprobs.content)
+        output["prob"] = np.exp(output["logprob"])
+    return output
+
+
+def _anthropic_completion(model: str, system: str, prompt: str, response_model) -> str:
+    llm = ChatAnthropic(model=model).with_structured_output(response_model, include_raw=True)
+    output = llm.invoke([("system", system), ("human", prompt)])
+    return {"content": output.raw, "output": output.parsed.model_dump()}
+
+
+if __name__ == "__main__":
+
+    class ExplainedAnswer(BaseModel):
+        """
+        The answer to the question and a terse explanation of the answer.
+        """
+
+        answer: str = Field(description="The short answer to the question")
+        explanation: str = Field(description="5 words terse best explanation of the answer.")
+
+    model = "Anthropic/claude-3-5-sonnet-20240620"
+    system = "You are an accurate and concise explainer of scientific concepts."
+    prompt = "Which planet is closest to the sun in the Milky Way galaxy? Answer directly, no explanation needed."
+
+    # response = _cohere_completion("command-r", system, prompt, ExplainedAnswer, logprobs=True)
+    response = completion(model, system, prompt, ExplainedAnswer, logprobs=False)
+    rprint(response)
+
+# %%
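The `__main__` block exercises the Anthropic path with logprobs off. A companion sketch for the logprobs path, assuming an "OpenAI/gpt-4o-mini" key exists in AVAILABLE_MODELS (the actual keys live in envs.py, so this name is illustrative only):

from pydantic import BaseModel, Field

from llms import completion


class ShortAnswer(BaseModel):
    answer: str = Field(description="The short answer to the question")


# "OpenAI/gpt-4o-mini" is an assumed AVAILABLE_MODELS key; completion() resolves it to a provider model.
result = completion(
    model="OpenAI/gpt-4o-mini",
    system="Answer tersely.",
    prompt="What is the capital of France?",
    response_format=ShortAnswer,
    logprobs=True,
)
print(result["output"]["answer"], result["prob"])  # parsed answer plus exp(sum of token logprobs)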
src/utils.py CHANGED
@@ -8,8 +8,10 @@ def guess_model_provider(model_name: str):
     model_name = model_name.lower()
     if model_name.startswith("gpt-"):
         return "OpenAI"
-    if "sonnet" in model_name or "claude" in model_name:
+    if "sonnet" in model_name or "claude" in model_name or "haiku" in model_name:
        return "Anthropic"
+    if "command" in model_name:
+        return "Cohere"
     raise ValueError(f"Model `{model_name}` not yet supported")
 
 
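A quick check of the broadened detection, assuming the module is importable as src.utils (model names below are illustrative):

from src.utils import guess_model_provider

assert guess_model_provider("gpt-4o") == "OpenAI"
assert guess_model_provider("claude-3-5-haiku-latest") == "Anthropic"  # "haiku" now matches
assert guess_model_provider("command-r-plus") == "Cohere"  # new Cohere branch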
src/workflows/qb/multi_step_agent.py ADDED
@@ -0,0 +1,192 @@
+import time
+from typing import Any, Iterable
+
+from workflows.executors import execute_workflow
+from workflows.structs import Workflow
+
+
+def _get_workflow_response(workflow: Workflow, available_vars: dict[str, Any]) -> tuple[dict[str, Any], str, float]:
+    """Get response from executing a complete workflow."""
+    start_time = time.time()
+    response, content = execute_workflow(workflow, available_vars, return_full_content=True)
+    response_time = time.time() - start_time
+    return response, content, response_time
+
+
+class MultiStepTossupAgent:
+    """Agent for handling tossup questions with multiple steps in the workflow."""
+
+    external_input_variable = "question_text"
+    output_variables = ["answer", "confidence"]
+
+    def __init__(self, workflow: Workflow, buzz_threshold: float):
+        """Initialize the multi-step tossup agent.
+
+        Args:
+            workflow: The workflow containing multiple steps
+            buzz_threshold: Confidence threshold for buzzing
+        """
+        self.workflow = workflow
+        self.buzz_threshold = buzz_threshold
+        self.output_variables = list(workflow.outputs.keys())
+
+        # Validate input variables
+        if self.external_input_variable not in workflow.inputs:
+            raise ValueError(f"External input variable {self.external_input_variable} not found in workflow inputs")
+
+        # Validate output variables
+        for out_var in self.output_variables:
+            if out_var not in workflow.outputs:
+                raise ValueError(f"Output variable {out_var} not found in workflow outputs")
+
+    def run(self, question_runs: list[str], early_stop: bool = True) -> Iterable[dict]:
+        """Process a tossup question and decide when to buzz based on confidence.
+
+        Args:
+            question_runs: Progressive reveals of the question text
+            early_stop: Whether to stop after the first buzz
+
+        Yields:
+            Dict containing:
+                - answer: The model's answer
+                - confidence: Confidence score
+                - buzz: Whether to buzz
+                - question_fragment: Current question text
+                - position: Current position in question
+                - full_response: Complete model response
+                - response_time: Time taken for response
+                - step_outputs: Outputs from each step
+        """
+        for i, question_text in enumerate(question_runs):
+            # Execute the complete workflow
+            response, content, response_time = _get_workflow_response(
+                self.workflow, {self.external_input_variable: question_text}
+            )
+
+            buzz = response["confidence"] >= self.buzz_threshold
+            result = {
+                "answer": response["answer"],
+                "confidence": response["confidence"],
+                "buzz": buzz,
+                "question_fragment": question_text,
+                "position": i + 1,
+                "full_response": content,
+                "response_time": response_time,
+                "step_outputs": response.get("step_outputs", {}),  # Include intermediate step outputs
+            }
+
+            yield result
+
+            # If we've reached the confidence threshold, buzz and stop
+            if early_stop and buzz:
+                return
+
+
+class MultiStepBonusAgent:
+    """Agent for handling bonus questions with multiple steps in the workflow."""
+
+    external_input_variables = ["leadin", "part"]
+    output_variables = ["answer", "confidence", "explanation"]
+
+    def __init__(self, workflow: Workflow):
+        """Initialize the multi-step bonus agent.
+
+        Args:
+            workflow: The workflow containing multiple steps
+        """
+        self.workflow = workflow
+        self.output_variables = list(workflow.outputs.keys())
+
+        # Validate input variables
+        for input_var in self.external_input_variables:
+            if input_var not in workflow.inputs:
+                raise ValueError(f"External input variable {input_var} not found in workflow inputs")
+
+        # Validate output variables
+        for out_var in self.output_variables:
+            if out_var not in workflow.outputs:
+                raise ValueError(f"Output variable {out_var} not found in workflow outputs")
+
+    def run(self, leadin: str, part: str) -> dict:
+        """Process a bonus part with the given leadin.
+
+        Args:
+            leadin: The leadin text for the bonus question
+            part: The specific part text to answer
+
+        Returns:
+            Dict containing:
+                - answer: The model's answer
+                - confidence: Confidence score
+                - explanation: Explanation for the answer
+                - full_response: Complete model response
+                - response_time: Time taken for response
+                - step_outputs: Outputs from each step
+        """
+        response, content, response_time = _get_workflow_response(
+            self.workflow,
+            {
+                "leadin": leadin,
+                "part": part,
+            },
+        )
+
+        return {
+            "answer": response["answer"],
+            "confidence": response["confidence"],
+            "explanation": response["explanation"],
+            "full_response": content,
+            "response_time": response_time,
+            "step_outputs": response.get("step_outputs", {}),  # Include intermediate step outputs
+        }
+
+
+# Example usage
+if __name__ == "__main__":
+    # Load the Quizbowl dataset
+    from datasets import load_dataset
+
+    from workflows.factory import create_quizbowl_bonus_workflow, create_quizbowl_tossup_workflow
+
+    ds_name = "umdclip/leaderboard_co_set"
+    ds = load_dataset(ds_name, split="train")
+
+    # Create the agents with multi-step workflows
+    tossup_workflow = create_quizbowl_tossup_workflow()
+    tossup_agent = MultiStepTossupAgent(workflow=tossup_workflow, buzz_threshold=0.9)
+
+    bonus_workflow = create_quizbowl_bonus_workflow()
+    bonus_agent = MultiStepBonusAgent(workflow=bonus_workflow)
+
+    # Example for tossup mode
+    print("\n=== TOSSUP MODE EXAMPLE ===")
+    sample_question = ds[30]
+    print(sample_question["question_runs"][-1])
+    print(sample_question["gold_label"])
+    print()
+    question_runs = sample_question["question_runs"]
+
+    results = tossup_agent.run(question_runs, early_stop=True)
+    for result in results:
+        print(result["full_response"])
+        print(f"Guess at position {result['position']}: {result['answer']}")
+        print(f"Confidence: {result['confidence']}")
+        print("Step outputs:", result["step_outputs"])
+        if result["buzz"]:
+            print("Buzzed!\n")
+
+    # Example for bonus mode
+    print("\n=== BONUS MODE EXAMPLE ===")
+    sample_bonus = ds[31]  # Assuming this is a bonus question
+    leadin = sample_bonus["leadin"]
+    parts = sample_bonus["parts"]
+
+    print(f"Leadin: {leadin}")
+    for i, part in enumerate(parts):
+        print(f"\nPart {i + 1}: {part['part']}")
+        result = bonus_agent.run(leadin, part["part"])
+        print(f"Answer: {result['answer']}")
+        print(f"Confidence: {result['confidence']}")
+        print(f"Explanation: {result['explanation']}")
+        print(f"Response time: {result['response_time']:.2f}s")
+        print("Step outputs:", result["step_outputs"])
src/workflows/qb/simple_agent.py CHANGED
@@ -33,14 +33,6 @@ def _get_model_step_response(
     return response, content, response_time
 
 
-def _get_workflow_response(workflow: Workflow, available_vars: dict[str, Any]) -> tuple[dict[str, Any], str, float]:
-    """Get response from the LLM model."""
-    start_time = time.time()
-    response, content = execute_workflow(workflow, available_vars, return_full_content=True)
-    response_time = time.time() - start_time
-    return response, content, response_time
-
-
 class SimpleTossupAgent:
     external_input_variable = "question_text"
     output_variables = ["answer", "confidence"]