|
|
|
import json |
|
import logging |
|
import re |
|
from collections import Counter |
|
|
|
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
|
import numpy as np |
|
import pandas as pd |
|
|
|
|
|
def _make_answer_html(answer: str, clean_answers: list[str] | None = None) -> str:
    # Avoid a mutable default argument; treat a missing list as empty.
    clean_answers = [a for a in (clean_answers or []) if len(a.split()) <= 6 and a != answer]
|
additional_answers_html = "" |
|
if clean_answers: |
|
additional_answers_html = f"<span class='bonus-answer-text'> [or {', '.join(clean_answers)}]</span>" |
|
return f""" |
|
<div class='bonus-answer'> |
|
<span class='bonus-answer-label'>Answer: </span> |
|
<span class='bonus-answer-text'>{answer}</span> |
|
{additional_answers_html} |
|
</div> |
|
""" |
|
|
|
|
|
def _get_token_classes(confidence, buzz, score) -> str:
    if confidence is None:
        return "token"
    # Buzzed and non-buzzed guess points currently share the same styling classes.
    return f"token guess-point buzz-{score}"
|
|
|
|
|
def _create_token_tooltip_html(values) -> str: |
|
if not values: |
|
return "" |
|
confidence = values.get("confidence", 0) |
|
buzz = values.get("buzz", 0) |
|
score = values.get("score", 0) |
|
answer = values.get("answer", "") |
|
answer_tokens = answer.split() |
|
if len(answer_tokens) > 10: |
|
k = len(answer_tokens) - 10 |
|
answer = " ".join(answer_tokens[:10]) + f"...[{k} more words]" |
|
|
|
color = "#a3c9a3" if score else "#ebbec4" |
|
|
|
if values.get("logprob", None) is not None: |
|
prob = np.exp(values["logprob"]) |
|
prob_str = f"<p style='margin: 0 0 4px; color: #000;'> π <b style='color: #000;'>Output Probability:</b> {prob:.3f}</p>" |
|
else: |
|
prob_str = "" |
|
|
|
return f""" |
|
<div class="tooltip card" style="background-color: {color}; border-radius: 8px; padding: 12px; box-shadow: 2px 4px 8px rgba(0, 0, 0, 0.15);"> |
|
<div class="tooltip-content" style="font-family: 'Arial', sans-serif; color: #000;"> |
|
<h4 style="margin: 0 0 8px; color: #000;">π‘ Answer</h4> |
|
<p><code style="font-weight: bold; margin: 0 0 8px; color: #000;">{answer}</code></p> |
|
<p style="margin: 0 0 4px; color: #000;">π <b style="color: #000;">Confidence:</b> {confidence:.2f}</p> |
|
{prob_str} |
|
<p style="margin: 0; color: #000;">π <b style="color: #000;">Status:</b> {"β
Correct" if score else "β Incorrect" if buzz else "π« No Buzz"}</p> |
|
</div> |
|
</div> |
|
""" |
|
|
|
|
|
def create_token_html(token: str, values: dict, i: int) -> str: |
|
confidence = values.get("confidence", None) |
|
buzz = values.get("buzz", 0) |
|
score = values.get("score", 0) |
|
|
|
|
|
display_token = f"{token} π¨" if buzz else f"{token} π" if values else token |
|
if not re.match(r"\w+", token): |
|
display_token = token.replace(" ", " ") |
|
|
|
css_class = _get_token_classes(confidence, buzz, score) |
|
|
|
tooltip_html = _create_token_tooltip_html(values) |
|
|
|
token_html = f'<span id="token-{i}" class="{css_class}" data-index="{i}">{display_token}{tooltip_html}</span>' |
|
|
|
|
|
return token_html |
|
|
|
|
|
def create_tossup_html( |
|
tokens: list[str], |
|
answer_primary: str, |
|
clean_answers: list[str], |
|
    marker_indices: list[int] | None = None,
|
    eval_points: list[tuple[int, dict]] | None = None,
|
) -> str: |
|
"""Create HTML for tokens with hover capability and a colored header for the answer.""" |
|
try: |
|
        # Treat missing lists as empty (avoids mutable default arguments).
        ep = dict(eval_points or [])
        marker_indices = set(marker_indices or [])
|
|
|
html_tokens = [] |
|
for i, token in enumerate(tokens): |
|
token_html = create_token_html(token, ep.get(i, {}), i + 1) |
|
html_tokens.append(token_html) |
|
|
|
answer_html = _make_answer_html(answer_primary, clean_answers) |
|
return f""" |
|
<div class='bonus-container'> |
|
<div class='bonus-card'> |
|
<div class='tossup-question'> |
|
{"".join(html_tokens)} |
|
</div> |
|
{answer_html} |
|
</div> |
|
</div> |
|
""" |
|
except Exception as e: |
|
logging.error(f"Error creating token HTML: {e}", exc_info=True) |
|
return f"<div class='token-container'>Error creating tokens: {str(e)}</div>" |
|
|
|
|
|
def create_bonus_html(leadin: str, parts: list[dict]) -> str: |
|
|
|
leadin_html = f"<div class='bonus-leadin'>{leadin}</div>" |
|
parts_html = [] |
|
|
|
for i, part in enumerate(parts): |
|
question_text = part["part"] |
|
answer_html = _make_answer_html(part["answer_primary"], part["clean_answers"]) |
|
|
|
"<div class='bonus-part-number'>Part {i + 1}</div>" |
|
part_html = f""" |
|
<div class='bonus-part'> |
|
<div class='bonus-part-text'><b>#{i + 1}.</b> {question_text}</div> |
|
{answer_html} |
|
</div> |
|
""" |
|
parts_html.append(part_html) |
|
|
|
html_content = f""" |
|
<div class='bonus-container'> |
|
<div class='bonus-card'> |
|
{leadin_html} |
|
{"".join(parts_html)} |
|
</div> |
|
</div> |
|
""" |
|
|
|
|
|
|
|
|
return html_content |
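

# Illustrative sketch of the `parts` structure create_bonus_html expects
# (keys inferred from the lookups above); the helper name and question content are made up.
def _example_bonus_html() -> str:
    leadin = "For 10 points each, answer the following about Russian novels."
    parts = [
        {
            "part": "Name the author of War and Peace.",
            "answer_primary": "Leo Tolstoy",
            "clean_answers": ["Tolstoy", "Leo Tolstoy"],
        },
        {
            "part": "Name the Tolstoy novel about Anna's affair with Vronsky.",
            "answer_primary": "Anna Karenina",
            "clean_answers": ["Anna Karenina"],
        },
    ]
    return create_bonus_html(leadin, parts)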
|
|
|
|
|
def create_tossup_confidence_pyplot( |
|
tokens: list[str], |
|
eval_points: list[tuple[int, dict]], |
|
confidence_threshold: float = 0.5, |
|
prob_threshold: float | None = None, |
|
) -> plt.Figure: |
|
"""Create a pyplot of token values with optional highlighting.""" |
|
plt.style.use("ggplot") |
|
fig = plt.figure(figsize=(10, 4), dpi=300) |
|
ax = fig.add_subplot(111) |
|
    x = [0] + [int(i + 1) for i, _ in eval_points]
    y_conf = [0] + [v["confidence"] for _, v in eval_points]
    # Build the probability curve only from points that actually have a logprob,
    # keeping its x and y values aligned (some outputs may have logprob=None).
    prob_points = [(int(i + 1), np.exp(v["logprob"])) for i, v in eval_points if v["logprob"] is not None]
    if prob_points:
        x_prob, y_prob = zip(*prob_points)
        ax.plot([0, *x_prob], [0, *y_prob], "o-", color="#f2b150", label="Probability")
    ax.plot(x, y_conf, "o-", color="#4996de", label="Confidence")
|
for i, v in eval_points: |
|
if not v["buzz"]: |
|
continue |
|
color = "green" if v["score"] else "red" |
|
conf = v["confidence"] |
|
ax.plot(i + 1, conf, "o", color=color, markerfacecolor="none", markersize=12, markeredgewidth=2.5) |
|
if v["logprob"] is not None: |
|
prob = np.exp(v["logprob"]) |
|
ax.plot(i + 1, prob, "o", color=color, markerfacecolor="none", markersize=12, markeredgewidth=2.5) |
|
        if i >= len(tokens):
            logging.warning(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")
            continue
        ax.annotate(tokens[i], (i + 1, conf), textcoords="offset points", xytext=(0, 10), ha="center")
|
|
|
|
|
ax.axhline(y=confidence_threshold, color="#9370DB", linestyle="--", xmin=0, xmax=1, label="Confidence Threshold") |
|
|
|
if prob_threshold is not None: |
|
ax.axhline(y=prob_threshold, color="#cf5757", linestyle="--", xmin=0, xmax=1, label="Probability Threshold") |
|
|
|
ax.set_title("Buzz Confidence") |
|
ax.set_xlabel("Token Index") |
|
ax.set_ylabel("Confidence") |
|
ax.set_xticks(x) |
|
ax.set_xticklabels(x) |
|
ax.legend() |
|
return fig |
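

# Illustrative usage sketch (assumed data, not from the app): demonstrates how
# `eval_points` pairs token indices with model-output dicts and how the optional
# probability threshold adds a second dashed line. The helper name is hypothetical.
def _example_confidence_plot() -> plt.Figure:
    tokens = "The quick brown fox jumps over the lazy dog".split()
    eval_points = [
        (2, {"confidence": 0.2, "buzz": 0, "score": 0, "logprob": -2.0}),
        (5, {"confidence": 0.6, "buzz": 0, "score": 1, "logprob": -0.7}),
        (8, {"confidence": 0.9, "buzz": 1, "score": 1, "logprob": -0.1}),
    ]
    return create_tossup_confidence_pyplot(tokens, eval_points, confidence_threshold=0.5, prob_threshold=0.3)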
|
|
|
|
|
def create_scatter_pyplot(token_positions: list[int], scores: list[int]) -> plt.Figure: |
|
"""Create a scatter plot of token positions and scores.""" |
|
plt.style.use("ggplot") |
|
fig = plt.figure(figsize=(11, 5)) |
|
ax = fig.add_subplot(111) |
|
|
|
counts = Counter(zip(token_positions, scores)) |
|
X = [] |
|
Y = [] |
|
S = [] |
|
for (pos, score), size in counts.items(): |
|
X.append(pos) |
|
Y.append(score) |
|
S.append(size * 20) |
|
|
|
ax.scatter(X, Y, color="#4698cf", s=S) |
|
|
|
return fig |
|
|
|
|
|
def create_bonus_confidence_plot(parts: list[dict], model_outputs: list[dict]) -> plt.Figure: |
|
"""Create confidence plot for bonus parts.""" |
|
plt.style.use("ggplot") |
|
fig = plt.figure(figsize=(10, 6)) |
|
ax = fig.add_subplot(111) |
|
|
|
|
|
x = range(1, len(parts) + 1) |
|
confidences = [output["confidence"] for output in model_outputs] |
|
scores = [output["score"] for output in model_outputs] |
|
|
|
|
|
bars = ax.bar(x, confidences, color="#4698cf") |
|
|
|
|
|
for i, score in enumerate(scores): |
|
bars[i].set_color("green" if score == 1 else "red") |
|
|
|
ax.set_title("Part Confidence") |
|
ax.set_xlabel("Part Number") |
|
ax.set_ylabel("Confidence") |
|
ax.set_xticks(x) |
|
ax.set_xticklabels([f"Part {i}" for i in x]) |
|
|
|
return fig |
|
|
|
|
|
def update_tossup_plot(highlighted_index: int, state: str) -> plt.Figure:
    """Rebuild the tossup confidence plot when a token is hovered."""
|
try: |
|
if not state or state == "{}": |
|
logging.warning("Empty state provided to update_plot") |
|
return pd.DataFrame() |
|
|
|
highlighted_index = int(highlighted_index) if highlighted_index else None |
|
logging.info(f"Update plot triggered with token index: {highlighted_index}") |
|
|
|
data = json.loads(state) |
|
tokens = data.get("tokens", []) |
|
values = data.get("values", []) |
|
|
|
        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return plt.figure()
|
|
|
|
|
|
|
        # create_tossup_confidence_pyplot has no highlight parameter; the hovered
        # token index is only logged above.
        return create_tossup_confidence_pyplot(tokens, values)
|
    except Exception as e:
        logging.error(f"Error updating plot: {e}", exc_info=True)
        return plt.figure()
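

# Minimal sketch of the JSON `state` payload update_tossup_plot expects
# (structure inferred from the parsing above, not a documented schema): a dict
# with "tokens" and "values", where "values" mirrors the eval_points pairs.
def _example_update_tossup_plot() -> plt.Figure:
    state = json.dumps(
        {
            "tokens": ["An", "example", "tossup", "question"],
            "values": [
                [1, {"confidence": 0.4, "buzz": 0, "score": 0, "logprob": -1.0}],
                [3, {"confidence": 0.8, "buzz": 1, "score": 1, "logprob": -0.2}],
            ],
        }
    )
    return update_tossup_plot(highlighted_index=3, state=state)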
|
|
|
|
|
|
|
|
|
|
|
def create_df_entry(run_indices: list[int], run_outputs: list[dict]) -> dict: |
|
"""Create a dataframe entry from a list of model outputs.""" |
|
chosen_idx = None |
|
earliest_ok_idx = None |
|
is_correct = None |
|
for i, o in enumerate(run_outputs): |
|
if chosen_idx is None and o["buzz"]: |
|
chosen_idx = run_indices[o["position"] - 1] + 1 |
|
is_correct = o["score"] |
|
if earliest_ok_idx is None and o["score"]: |
|
earliest_ok_idx = run_indices[o["position"] - 1] + 1 |
|
if is_correct is None: |
|
is_correct = False |
|
|
|
|
|
|
|
|
|
    if chosen_idx is None:
        # Never buzzed: neither reward nor penalty.
        tossup_score = 0
|
elif chosen_idx == run_indices[-1] + 1: |
|
tossup_score = 5 if is_correct else 0 |
|
else: |
|
tossup_score = 10 if is_correct else -5 |
|
|
|
gap = None if (chosen_idx is None or earliest_ok_idx is None) else chosen_idx - earliest_ok_idx |
|
if earliest_ok_idx is None: |
|
cls = "hopeless" |
|
elif chosen_idx is None: |
|
cls = "never-buzzed" |
|
elif chosen_idx == earliest_ok_idx: |
|
cls = "best-buzz" |
|
elif chosen_idx > earliest_ok_idx: |
|
cls = "late-buzz" |
|
    else:  # chosen_idx < earliest_ok_idx
        cls = "premature"
|
|
|
return { |
|
"chosen_idx": chosen_idx, |
|
"earliest_ok_idx": earliest_ok_idx, |
|
"gap": gap, |
|
"cls": cls, |
|
"tossup_score": tossup_score, |
|
"is_correct": int(is_correct), |
|
} |
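

# Worked example (hypothetical numbers) of the scoring rules above: a correct buzz
# before the final probed index scores 10 and an incorrect one -5; at the final
# index a correct buzz scores 5 (0 if wrong); never buzzing scores 0.
def _example_df_entry() -> dict:
    run_indices = [10, 25, 40]  # token positions at which the model was probed
    run_outputs = [
        {"position": 1, "buzz": False, "score": 0},
        {"position": 2, "buzz": True, "score": 1},  # first buzz, and it is correct
        {"position": 3, "buzz": True, "score": 1},
    ]
    entry = create_df_entry(run_indices, run_outputs)
    # chosen_idx == earliest_ok_idx == 26, so cls == "best-buzz" and tossup_score == 10.
    return entry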
|
|
|
|
|
def prepare_tossup_results_df(run_indices: list[list[int]], model_outputs: list[list[dict]]) -> pd.DataFrame: |
|
"""Create a dataframe from a list of model outputs.""" |
|
records = [] |
|
for indices, outputs in zip(run_indices, model_outputs): |
|
entry = create_df_entry(indices, outputs) |
|
records.append(entry) |
|
return pd.DataFrame.from_records(records) |
|
|
|
|
|
def create_tossup_eval_table(df: pd.DataFrame) -> pd.DataFrame: |
|
"""Create a table from a dataframe.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
positions = df["chosen_idx"].dropna() |
|
gaps = df["gap"].dropna() |
|
pos_gaps = gaps.loc[gaps >= 0] |
|
neg_gaps = gaps.loc[gaps < 0] |
|
|
|
mean_tossup_score = df["tossup_score"].sum() / len(df) |
|
|
|
return pd.DataFrame( |
|
[ |
|
{ |
|
"Tossup Score (10)": f"{mean_tossup_score:5.1f}", |
|
"Buzz Accuracy": f"{df['is_correct'].mean():5.1%}", |
|
"Buzz Position": f"{np.mean(positions):5.1f}", |
|
"+ve Gap": f"{pos_gaps.mean():5.1f}", |
|
"-ve Gap": f"{neg_gaps.mean():5.1f}", |
|
} |
|
] |
|
) |
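

# Illustrative sketch (made-up data, hypothetical helper name): build per-question
# entries, collect them into a dataframe, and summarise them with create_tossup_eval_table.
def _example_eval_table() -> pd.DataFrame:
    run_indices = [[10, 20, 30], [10, 20, 30], [10, 20, 30]]
    model_outputs = [
        [  # best-buzz: first correct answer and the buzz coincide at token 21
            {"position": 1, "buzz": False, "score": 0},
            {"position": 2, "buzz": True, "score": 1},
            {"position": 3, "buzz": True, "score": 1},
        ],
        [  # late-buzz: correct from token 11, but the buzz comes at token 21
            {"position": 1, "buzz": False, "score": 1},
            {"position": 2, "buzz": True, "score": 1},
            {"position": 3, "buzz": True, "score": 1},
        ],
        [  # premature: buzzes incorrectly at token 11, would be correct from token 21
            {"position": 1, "buzz": True, "score": 0},
            {"position": 2, "buzz": False, "score": 1},
            {"position": 3, "buzz": False, "score": 1},
        ],
    ]
    df = prepare_tossup_results_df(run_indices, model_outputs)
    return create_tossup_eval_table(df)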
|
|
|
|
|
def create_tossup_eval_dashboard(run_indices: list[list[int]], df: pd.DataFrame, *, figsize=(15, 8), title_prefix=""): |
|
""" |
|
Visualise buzzing behaviour with three sub-plots: |
|
|
|
1. Ceiling-accuracy vs. prefix length |
|
2. Scatter of earliest-correct idx vs. chosen-buzz idx |
|
3. Frequency distribution of narrative classes (vertical bars) |
|
|
|
Parameters |
|
---------- |
|
    run_indices : list[list[int]]
        Token positions at which each question was probed.
    df : pd.DataFrame
        Output of `prepare_tossup_results_df`; must contain the
        columns: earliest_ok_idx, chosen_idx, cls.
|
figsize : tuple, optional |
|
Figure size passed to `plt.subplots`. |
|
title_prefix : str, optional |
|
Prepended to each subplot title (useful when comparing models). |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
eval_indices = np.asarray(sorted({idx for indices in run_indices for idx in indices})) |
|
|
|
|
|
classes = [ |
|
"best-buzz", |
|
"late-buzz", |
|
"never-buzzed", |
|
"premature", |
|
"hopeless", |
|
] |
|
colors = ["tab:green", "tab:olive", "tab:orange", "tab:red", "tab:gray"] |
|
palette = dict(zip(classes, colors)) |
|
|
|
max_idx = eval_indices.max() * 1.25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plt.style.use("ggplot") |
|
fig = plt.figure(figsize=figsize) |
|
gs = fig.add_gridspec( |
|
nrows=2, |
|
ncols=3, |
|
height_ratios=[5, 1], |
|
width_ratios=[2.2, 2.2, 1], |
|
hspace=0.2, |
|
wspace=0.2, |
|
left=0.05, |
|
right=0.95, |
|
top=0.9, |
|
bottom=0.05, |
|
) |
|
|
|
ax_ceiling = fig.add_subplot(gs[0, 0]) |
|
ax_scatter = fig.add_subplot(gs[0, 1]) |
|
ax_bars = fig.add_subplot(gs[0, 2]) |
|
ax_desc = fig.add_subplot(gs[1, :]) |
|
ax_desc.axis("off") |
|
|
|
fig.suptitle("Buzzing behaviour", fontsize=16, fontweight="bold") |
|
|
|
|
|
|
|
|
|
ceiling = [((df["earliest_ok_idx"].notna()) & (df["earliest_ok_idx"] <= idx)).mean() for idx in eval_indices] |
|
ax_ceiling.plot(eval_indices, ceiling, marker="o", color="#4698cf") |
|
ax_ceiling.set_xlabel("Token index shown") |
|
ax_ceiling.set_ylabel("Proportion of questions correct") |
|
ax_ceiling.set_ylim(0, 1.01) |
|
ax_ceiling.set_title(f"{title_prefix}Ceiling accuracy vs. prefix") |
|
|
|
|
|
|
|
|
|
for cls in classes: |
|
sub = df[df["cls"] == cls] |
|
if sub.empty: |
|
continue |
|
x = sub["earliest_ok_idx"].fillna(max_idx) |
|
y = sub["chosen_idx"].fillna(max_idx) |
|
ax_scatter.scatter( |
|
x, |
|
y, |
|
label=cls, |
|
alpha=0.7, |
|
edgecolor="black", |
|
linewidth=1, |
|
marker="o", |
|
s=90, |
|
            c=palette[cls],
        )
|
|
|
lim = max_idx |
|
ax_scatter.plot([0, lim], [0, lim], linestyle=":", linewidth=1) |
|
ax_scatter.set_xlim(0, lim) |
|
ax_scatter.set_ylim(0, lim) |
|
ax_scatter.set_xlabel("Earliest index with correct answer") |
|
ax_scatter.set_ylabel("Chosen buzz index") |
|
ax_scatter.set_title(f"{title_prefix}Earliest vs. chosen index") |
|
ax_scatter.legend(frameon=False, fontsize="small") |
|
|
|
|
|
|
|
|
|
counts = df["cls"].value_counts().reindex(classes).fillna(0) |
|
ax_bars.barh( |
|
counts.index, |
|
counts.values, |
|
color=[palette[c] for c in counts.index], |
|
alpha=0.7, |
|
edgecolor="black", |
|
linewidth=1, |
|
) |
|
ax_bars.set_xlabel("Number of questions") |
|
ax_bars.set_title(f"{title_prefix}Outcome distribution") |
|
|
|
|
|
    # Question counts are integers, so force integer ticks.
    ax_bars.xaxis.set_major_locator(MaxNLocator(integer=True))
|
|
|
|
|
|
|
|
|
descriptions = { |
|
"best-buzz": "Perfect timing. Buzzed at the earliest possible correct position", |
|
"late-buzz": "Missed opportunity. Buzzed correctly but later than optimal", |
|
"never-buzzed": "Missed opportunity. Never buzzed despite knowing the answer", |
|
"premature": "Incorrect buzz. Buzzing at a later position could have been correct", |
|
"hopeless": "Never knew the answer. No correct answer at any position", |
|
} |
|
|
|
y_pos = 1.0 |
|
|
|
for cls, color in zip(classes, colors): |
|
ax_desc.text( |
|
0.01, |
|
y_pos, |
|
f"β {cls}: {descriptions[cls]}", |
|
ha="left", |
|
va="top", |
|
color=color, |
|
fontweight="bold", |
|
fontsize=11, |
|
transform=ax_desc.transAxes, |
|
) |
|
|
|
y_pos -= 0.25 |
|
|
|
|
|
|
|
|
|
return fig |
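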
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_dummy_model_outputs(n_entries=10, n_positions=5): |
|
"""Create dummy model outputs for testing.""" |
|
np.random.seed(42) |
|
dummy_outputs = [] |
|
|
|
for _ in range(n_entries): |
|
run_indices = sorted(np.random.choice(range(10, 50), n_positions, replace=False)) |
|
outputs = [] |
|
|
|
for i in range(n_positions): |
|
|
|
will_buzz = np.random.random() > 0.7 |
|
|
|
is_correct = np.random.random() > 0.4 |
|
|
|
outputs.append( |
|
{ |
|
"position": i + 1, |
|
"buzz": will_buzz, |
|
"score": 1 if is_correct else 0, |
|
"confidence": np.random.random(), |
|
"logprob": np.log(np.random.random()), |
|
"answer": f"Answer {i + 1}", |
|
} |
|
) |
|
|
|
dummy_outputs.append({"run_indices": run_indices, "outputs": outputs}) |
|
|
|
return dummy_outputs |
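

if __name__ == "__main__":
    # Smoke-test sketch (assumed local usage, not part of the app): feed the dummy
    # outputs through the tossup evaluation pipeline and render the dashboard.
    dummy = create_dummy_model_outputs(n_entries=20)
    run_indices = [d["run_indices"] for d in dummy]
    model_outputs = [d["outputs"] for d in dummy]
    results_df = prepare_tossup_results_df(run_indices, model_outputs)
    print(create_tossup_eval_table(results_df))
    create_tossup_eval_dashboard(run_indices, results_df)
    plt.show()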
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|