import json
import logging
import re
from collections import Counter
import matplotlib.pyplot as plt
import pandas as pd
def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int:
"""Evaluate the buzz of a prediction against the clean answers."""
if isinstance(clean_answers, str):
print("clean_answers is a string")
clean_answers = [clean_answers]
pred = prediction.lower().strip()
if not pred:
return 0
for answer in clean_answers:
answer = answer.strip().lower()
if answer and answer in pred:
print(f"Found {answer} in {pred}")
return 1
return 0
def create_answer_html(answer: str) -> str:
    """Create HTML for the answer.

    NOTE(review): as visible here, the returned f-string contains no markup,
    never references ``answer``, and is split across two physical lines. The
    literal looks truncated or stripped of its tags -- confirm against the
    original file before relying on this function.
    """
    return f"
"
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] | None = None) -> str:
    """Create the HTML string for a token sequence, tagging evaluated tokens.

    Args:
        tokens: Token strings, rendered in order.
        eval_points: ``(token_index, (confidence, buzz_point, score))`` pairs;
            converted to a dict keyed by token index.
        answer: Only referenced by the commented-out header line below.
        marker_indices: Token indices after which a ``|`` is appended. Any
            value that is not a ``list`` is treated as "no markers".

    Returns:
        The per-token strings joined together, or an error-message string if
        any exception is raised while building them (the error is logged).

    NOTE(review): the string literals in this function look stripped of their
    markup in this view (see the inline notes) -- confirm against the original
    file.
    """
    try:
        html_parts = []
        # Index the evaluation points by token position for O(1) lookup.
        ep = dict(eval_points)
        marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set()
        # Add a colored header for the answer (currently disabled)
        # html_parts.append(create_answer_html(answer))
        for i, token in enumerate(tokens):
            # Tokens without an evaluation point get confidence None.
            values = ep.get(i, (None, 0, 0))
            confidence, buzz_point, score = values
            # re.match anchors at the start, so this branch is taken for tokens
            # that do not begin with a word character.
            # NOTE(review): both replace() arguments render as a plain space
            # here, which would make this a no-op -- possibly an HTML entity
            # was lost; confirm.
            display_token = token
            if not re.match(r"\w+", token):
                display_token = token.replace(" ", " ")
            # Pick the CSS class suffix: none for unevaluated tokens, "no-buzz"
            # for evaluated tokens without a buzz, "buzz-<score>" otherwise.
            if confidence is None:
                css_class = ""
            elif not buzz_point:
                css_class = " guess-point no-buzz"
            else:
                css_class = f" guess-point buzz-{score}"
            # NOTE(review): css_class is not referenced by the f-string as it
            # appears here; the surrounding tag looks stripped -- confirm.
            token_html = f'{display_token}'
            if i in marker_indices:
                token_html += "|"
            html_parts.append(token_html)
        return f"{''.join(html_parts)}
"
    except Exception as e:
        # Report the failure in place of the tokens rather than raising.
        logging.error(f"Error creating token HTML: {e}", exc_info=True)
        return f"Error creating tokens: {str(e)}
"
def create_line_plot(eval_points, highlighted_index=-1):
    """Build the DataFrame for a Gradio LinePlot of token values.

    Args:
        eval_points: Iterable of ``(position, values)`` pairs where ``values``
            begins with ``(confidence, buzz, ...)``. Both the two-field form
            and the ``(confidence, buzz, score)`` triple used by the other
            helpers in this module are accepted; extra fields are ignored.
        highlighted_index: Token position to mark with a vertical line, drawn
            as two "hover-line" rows at value 0 and 1. Negative means no
            highlight.

    Returns:
        A DataFrame with columns ``position``, ``value``, ``type``,
        ``highlight`` and ``color``. One "buzz" row per evaluation point
        (red when ``buzz == 0``, green otherwise) plus the optional hover-line
        rows. On any error the problem is logged and an empty frame with the
        same columns is returned.
    """
    columns = ["position", "value", "type", "highlight", "color"]
    try:
        rows = []
        for position, values in eval_points:
            # Only confidence and the buzz flag are plotted; tolerate a
            # trailing score so 3-tuples do not fail the unpacking.
            confidence, buzz, *_ = values
            rows.append(
                {
                    "position": position,
                    "value": confidence,
                    "type": "buzz",
                    "highlight": True,
                    "color": "#ff4444" if buzz == 0 else "#228b22",
                }
            )

        if highlighted_index >= 0:
            # Two points at y=0 and y=1 form the vertical hover line.
            rows.extend(
                {
                    "position": highlighted_index,
                    "value": endpoint,
                    "type": "hover-line",
                    "color": "#000000",
                    "highlight": True,
                }
                for endpoint in (0, 1)
            )

        # Passing columns keeps the schema stable even when there are no rows.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        logging.error(f"Error creating line plot: {e}", exc_info=True)
        # Return an empty DataFrame with the expected columns
        return pd.DataFrame(columns=columns)
def create_pyplot(tokens, eval_points, highlighted_index=-1):
    """Create a pyplot figure of buzz confidence per evaluated token.

    Args:
        tokens: Token strings; used to label the points where a buzz occurred.
        eval_points: Iterable of ``(token_index, (confidence, buzz, score))``
            pairs.
        highlighted_index: Token index to mark with a dashed vertical line.
            Negative means no highlight.

    Returns:
        The matplotlib figure. The x axis is the 1-based token index; the line
        is anchored at the origin ``(0, 0)``.
    """
    # NOTE(review): plt.figure() registers the figure with pyplot until it is
    # closed; if this is called on every hover, callers may need to close old
    # figures -- confirm how the returned figure is consumed.
    plt.style.use("ggplot")  # Set theme to grid paper
    fig = plt.figure(figsize=(10, 6))  # Set figure size
    ax = fig.add_subplot(111)

    # Confidence curve, starting from the origin.
    x = [0]
    y = [0]
    for i, (v, _b, _s) in eval_points:
        x.append(i + 1)
        y.append(v)
    ax.plot(x, y, "o--", color="#4698cf")

    # Overlay the buzz points: green when the score is truthy, red otherwise.
    for i, (v, b, s) in eval_points:
        if not b:
            continue
        color = "green" if s else "red"
        ax.plot(i + 1, v, "o", color=color)
        if i >= len(tokens):
            # No token to label this point with; skip the annotation instead
            # of raising IndexError on tokens[i].
            logging.warning(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")
            continue
        ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")

    if highlighted_index >= 0:
        # Add light vertical line for the highlighted token from 0 to 1
        ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)

    ax.set_title("Buzz Confidence")
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Confidence")
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    return fig
def create_scatter_pyplot(token_positions, scores):
    """Draw a scatter of (token position, score) pairs sized by frequency.

    Args:
        token_positions: Iterable of token positions (x values).
        scores: Iterable of scores (y values), paired with ``token_positions``
            element by element via ``zip``.

    Returns:
        The matplotlib figure. Each distinct ``(position, score)`` pair is one
        marker whose ``s`` size is 20 times the number of occurrences.
    """
    plt.style.use("ggplot")
    figure = plt.figure(figsize=(10, 6))
    axes = figure.add_subplot(111)

    # Collapse duplicate pairs so repeated outcomes show up as larger markers.
    pair_counts = Counter(zip(token_positions, scores))
    xs = [pos for pos, _ in pair_counts]
    ys = [score for _, score in pair_counts]
    marker_sizes = [count * 20 for count in pair_counts.values()]

    axes.scatter(xs, ys, color="#4698cf", s=marker_sizes)
    return figure
def update_plot(highlighted_index, state):
    """Update the plot when a token is hovered; add a vertical line on the plot.

    Args:
        highlighted_index: Index of the hovered token, as an int or a numeric
            string. ``None`` or ``""`` means no token is highlighted.
        state: JSON string with a ``"tokens"`` list and a ``"values"`` list of
            evaluation points, both passed on to ``create_pyplot``.

    Returns:
        The figure from ``create_pyplot`` on success. An empty DataFrame is
        returned when the state is empty, has no tokens or values, or any
        error occurs (the error is logged).
    """
    try:
        if not state or state == "{}":
            logging.warning("Empty state provided to update_plot")
            return pd.DataFrame()

        # 0 is a valid token index, so test for "missing" explicitly rather
        # than by truthiness. -1 is create_pyplot's "no highlight" value;
        # None would fail its `highlighted_index >= 0` comparison.
        if highlighted_index is None or highlighted_index == "":
            highlighted_index = -1
        else:
            highlighted_index = int(highlighted_index)
        logging.info("Update plot triggered with token index: %s", highlighted_index)

        data = json.loads(state)
        tokens = data.get("tokens", [])
        values = data.get("values", [])
        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return pd.DataFrame()

        # Create updated plot with highlighting of the token point
        return create_pyplot(tokens, values, highlighted_index)
    except Exception as e:
        # Include the traceback, matching the other helpers in this module.
        logging.error(f"Error updating plot: {e}", exc_info=True)
        return pd.DataFrame()