File size: 6,754 Bytes
193db9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import html
import json
import logging
import re
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int:
"""Evaluate the buzz of a prediction against the clean answers."""
if isinstance(clean_answers, str):
print("clean_answers is a string")
clean_answers = [clean_answers]
pred = prediction.lower().strip()
if not pred:
return 0
for answer in clean_answers:
answer = answer.strip().lower()
if answer and answer in pred:
print(f"Found {answer} in {pred}")
return 1
return 0
def create_answer_html(answer: str):
    """Return the answer wrapped in an ``answer-header`` div.

    The answer text is inserted as-is after an ``Answer:`` label and a line
    break; no escaping is applied.
    """
    template = "<div class='answer-header'>Answer:<br>{}</div>"
    return template.format(answer)
def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None):
"""Create HTML for tokens with hover capability and a colored header for the answer."""
try:
html_parts = []
ep = dict(eval_points)
marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set()
# Add a colored header for the answer
# html_parts.append(create_answer_html(answer))
for i, token in enumerate(tokens):
# Check if this token is a buzz point
values = ep.get(i, (None, 0, 0))
confidence, buzz_point, score = values
# Replace non-word characters for proper display in HTML
display_token = token
if not re.match(r"\w+", token):
display_token = token.replace(" ", " ")
# Add buzz marker class if it's a buzz point
if confidence is None:
css_class = ""
elif not buzz_point:
css_class = " guess-point no-buzz"
else:
css_class = f" guess-point buzz-{score}"
token_html = f'<span id="token-{i}" class="token{css_class}" data-index="{i}">{display_token}</span>'
if i in marker_indices:
token_html += "<span style='color: rgba(0,0,255,0.3);'>|</span>"
html_parts.append(token_html)
return f"<div class='token-container'>{''.join(html_parts)}</div>"
except Exception as e:
logging.error(f"Error creating token HTML: {e}", exc_info=True)
return f"<div class='token-container'>Error creating tokens: {str(e)}</div>"
def create_line_plot(eval_points, highlighted_index=-1):
    """Build the DataFrame backing a line plot of evaluation points.

    Args:
        eval_points: Iterable of ``(position, values)`` pairs where ``values``
            starts with ``(value, buzz)``. Extra trailing elements (such as a
            score) are ignored, so both 2- and 3-element values are accepted.
        highlighted_index: Position at which to add a vertical "hover-line"
            (two rows, at value 0 and 1). Negative or ``None`` adds nothing.

    Returns:
        A DataFrame with columns ``position``, ``value``, ``type``,
        ``highlight`` and ``color``. Buzz rows are coloured ``#ff4444`` when
        ``buzz == 0`` and ``#228b22`` otherwise. On failure the error is
        logged and an empty frame with the same columns is returned.
    """
    # One fixed schema for the success, empty and error paths.
    columns = ["position", "value", "type", "highlight", "color"]
    try:
        data = [
            {
                "position": i,
                "value": v,
                "type": "buzz",
                "highlight": True,
                "color": "#ff4444" if b == 0 else "#228b22",
            }
            # *_ tolerates (value, buzz, score) triples as well as pairs.
            for i, (v, b, *_) in eval_points
        ]

        if highlighted_index is not None and highlighted_index >= 0:
            # Two points at y=0 and y=1 draw a vertical line at the token.
            data.extend(
                {
                    "position": highlighted_index,
                    "value": y,
                    "type": "hover-line",
                    "color": "#000000",
                    "highlight": True,
                }
                for y in (0, 1)
            )

        return pd.DataFrame(data, columns=columns)
    except Exception as e:
        logging.error(f"Error creating line plot: {e}", exc_info=True)
        return pd.DataFrame(columns=columns)
def create_pyplot(tokens, eval_points, highlighted_index=-1):
    """Plot buzz confidence per evaluation point and return the figure.

    Args:
        tokens: Token strings; used to label points where a buzz occurred.
        eval_points: Iterable of ``(token_index, (confidence, buzz, score))``
            pairs.
        highlighted_index: Token index at which to draw a dashed vertical
            line. Negative or ``None`` draws nothing.

    Returns:
        The matplotlib figure. The x axis is the token index shifted by one,
        with an origin point ``(0, 0)`` prepended to the confidence line.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # Confidence line, anchored at the origin.
    x = [0]
    y = [0]
    for i, (v, _, _) in eval_points:
        x.append(i + 1)
        y.append(v)
    ax.plot(x, y, "o--", color="#4698cf")

    # Overlay buzz points: green when score is truthy, red otherwise.
    for i, (v, b, s) in eval_points:
        if not b:
            continue
        color = "green" if s else "red"
        ax.plot(i + 1, v, "o", color=color)
        if 0 <= i < len(tokens):
            ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center")
        else:
            # No token to label: keep the point but skip the annotation
            # instead of raising IndexError.
            logging.warning(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}")

    if highlighted_index is not None and highlighted_index >= 0:
        # Full-height dashed line at the highlighted token.
        ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1)

    ax.set_title("Buzz Confidence")
    ax.set_xlabel("Token Index")
    ax.set_ylabel("Confidence")
    ax.set_xticks(x)
    ax.set_xticklabels(x)
    return fig
def create_scatter_pyplot(token_positions, scores):
    """Return a scatter figure of (token position, score) pairs.

    Positions and scores are paired element-wise; identical pairs are merged
    into one marker whose area is 20 times the number of occurrences.
    """
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    # Count duplicates so repeated points show up as larger markers.
    tally = Counter(zip(token_positions, scores))
    xs = [pos for pos, _ in tally]
    ys = [score for _, score in tally]
    sizes = [occurrences * 20 for occurrences in tally.values()]

    ax.scatter(xs, ys, color="#4698cf", s=sizes)
    return fig
def update_plot(highlighted_index, state):
    """Rebuild the confidence plot when a token is hovered.

    Args:
        highlighted_index: Index of the hovered token, as an int or a numeric
            string. ``None`` or an empty string means no token is highlighted.
        state: JSON string with ``tokens`` and ``values`` keys; ``values`` is
            passed to ``create_pyplot`` as its evaluation points.

    Returns:
        The figure from ``create_pyplot``, or an empty DataFrame when the
        state is empty, has no tokens/values, or an error occurs.
    """
    try:
        if not state or state == "{}":
            logging.warning("Empty state provided to update_plot")
            return pd.DataFrame()

        # Use -1 (not None) for "no highlight" so that index 0 stays valid and
        # create_pyplot's `>= 0` comparison never sees None.
        if highlighted_index is None or highlighted_index == "":
            highlighted_index = -1
        else:
            highlighted_index = int(highlighted_index)
        logging.info(f"Update plot triggered with token index: {highlighted_index}")

        data = json.loads(state)
        tokens = data.get("tokens", [])
        values = data.get("values", [])
        if not tokens or not values:
            logging.warning("No tokens or values found in state")
            return pd.DataFrame()

        return create_pyplot(tokens, values, highlighted_index)
    except Exception as e:
        logging.error(f"Error updating plot: {e}", exc_info=True)
        return pd.DataFrame()
|