import json import logging import re from collections import Counter import matplotlib.pyplot as plt import pandas as pd def evaluate_buzz(prediction: str, clean_answers: list[str] | str) -> int: """Evaluate the buzz of a prediction against the clean answers.""" if isinstance(clean_answers, str): print("clean_answers is a string") clean_answers = [clean_answers] pred = prediction.lower().strip() if not pred: return 0 for answer in clean_answers: answer = answer.strip().lower() if answer and answer in pred: print(f"Found {answer} in {pred}") return 1 return 0 def create_answer_html(answer: str): """Create HTML for the answer.""" return f"
Answer:
{answer}
" def create_tokens_html(tokens: list[str], eval_points: list[tuple], answer: str, marker_indices: list[int] = None): """Create HTML for tokens with hover capability and a colored header for the answer.""" try: html_parts = [] ep = dict(eval_points) marker_indices = set(marker_indices) if isinstance(marker_indices, list) else set() # Add a colored header for the answer # html_parts.append(create_answer_html(answer)) for i, token in enumerate(tokens): # Check if this token is a buzz point values = ep.get(i, (None, 0, 0)) confidence, buzz_point, score = values # Replace non-word characters for proper display in HTML display_token = token if not re.match(r"\w+", token): display_token = token.replace(" ", " ") # Add buzz marker class if it's a buzz point if confidence is None: css_class = "" elif not buzz_point: css_class = " guess-point no-buzz" else: css_class = f" guess-point buzz-{score}" token_html = f'{display_token}' if i in marker_indices: token_html += "|" html_parts.append(token_html) return f"
{''.join(html_parts)}
" except Exception as e: logging.error(f"Error creating token HTML: {e}", exc_info=True) return f"
Error creating tokens: {str(e)}
" def create_line_plot(eval_points, highlighted_index=-1): """Create a Gradio LinePlot of token values with optional highlighting using DataFrame.""" try: # Create base confidence data data = [] # Add buzz points to the plot for i, (v, b) in eval_points: color = "#ff4444" if b == 0 else "#228b22" data.append( { "position": i, "value": v, "type": "buzz", "highlight": True, "color": color, } ) if highlighted_index >= 0: # Add vertical line for the highlighted token data.extend( [ { "position": highlighted_index, "value": 0, "type": "hover-line", "color": "#000000", "highlight": True, }, { "position": highlighted_index, "value": 1, "type": "hover-line", "color": "#000000", "highlight": True, }, ] ) return pd.DataFrame(data) except Exception as e: logging.error(f"Error creating line plot: {e}", exc_info=True) # Return an empty DataFrame with the expected columns return pd.DataFrame(columns=["position", "value", "type", "highlight", "color"]) def create_pyplot(tokens, eval_points, highlighted_index=-1): """Create a pyplot of token values with optional highlighting.""" plt.style.use("ggplot") # Set theme to grid paper fig = plt.figure(figsize=(10, 6)) # Set figure size ax = fig.add_subplot(111) x = [0] y = [0] for i, (v, b, s) in eval_points: x.append(i + 1) y.append(v) ax.plot(x, y, "o--", color="#4698cf") for i, (v, b, s) in eval_points: if not b: continue color = "green" if s else "red" ax.plot(i + 1, v, "o", color=color) if i >= len(tokens): print(f"Token index {i} is out of bounds for n_tokens: {len(tokens)}") ax.annotate(f"{tokens[i]}", (i + 1, v), textcoords="offset points", xytext=(0, 10), ha="center") if highlighted_index >= 0: # Add light vertical line for the highlighted token from 0 to 1 ax.axvline(x=highlighted_index + 1, color="#ff9900", linestyle="--", ymin=0, ymax=1) ax.set_title("Buzz Confidence") ax.set_xlabel("Token Index") ax.set_ylabel("Confidence") ax.set_xticks(x) ax.set_xticklabels(x) return fig def create_scatter_pyplot(token_positions, scores): """Create a scatter plot of token positions and scores.""" plt.style.use("ggplot") fig = plt.figure(figsize=(10, 6)) ax = fig.add_subplot(111) counts = Counter(zip(token_positions, scores)) X = [] Y = [] S = [] for (pos, score), size in counts.items(): X.append(pos) Y.append(score) S.append(size * 20) ax.scatter(X, Y, color="#4698cf", s=S) return fig def update_plot(highlighted_index, state): """Update the plot when a token is hovered; add a vertical line on the plot.""" try: if not state or state == "{}": logging.warning("Empty state provided to update_plot") return pd.DataFrame() highlighted_index = int(highlighted_index) if highlighted_index else None logging.info(f"Update plot triggered with token index: {highlighted_index}") data = json.loads(state) tokens = data.get("tokens", []) values = data.get("values", []) if not tokens or not values: logging.warning("No tokens or values found in state") return pd.DataFrame() # Create updated plot with highlighting of the token point # plot_data = create_line_plot(values, highlighted_index) plot_data = create_pyplot(tokens, values, highlighted_index) return plot_data except Exception as e: logging.error(f"Error updating plot: {e}") return pd.DataFrame()