"""Utility metrics."""
import sqlglot
from rich.console import Console
from sqlglot import parse_one

console = Console(soft_wrap=True)


def correct_casing(sql: str) -> str:
    """Round-trip a SQL string through sqlglot and return the regenerated SQL.

    The input is parsed with the ``sqlite`` dialect and then re-emitted with
    sqlglot's default generator settings (no target dialect is passed).

    Args:
        sql: The SQL text to normalize.

    Returns:
        The SQL string produced by sqlglot for the parsed expression.

    NOTE(review): the casing applied to keywords/identifiers is whatever
    sqlglot's default generator produces; presumably input that sqlglot cannot
    parse raises a sqlglot parse error -- confirm against the sqlglot docs.
    """
    parse: sqlglot.expressions.Expression = parse_one(sql, read="sqlite")
    return parse.sql()


def prec_recall_f1(gold: set, pred: set) -> dict[str, float]:
    """Compute precision, recall and F1 score between two sets.

    Args:
        gold: The reference (ground-truth) items.
        pred: The predicted items.

    Returns:
        A dict with keys ``"prec"``, ``"recall"`` and ``"f1"``. Precision is
        ``0.0`` when ``pred`` is empty, recall is ``0.0`` when ``gold`` is
        empty, and F1 is ``0.0`` when precision and recall are both zero.
    """
    # Size of the overlap, computed once and shared by both ratios.
    overlap = len(gold.intersection(pred))

    # Guard each ratio against an empty denominator.
    prec = overlap / len(pred) if pred else 0.0
    recall = overlap / len(gold) if gold else 0.0

    # Harmonic mean; defined as 0.0 when both components are zero.
    f1 = 2 * prec * recall / (prec + recall) if prec + recall else 0.0
    return {"prec": prec, "recall": recall, "f1": f1}


def edit_distance(s1: str, s2: str) -> int:
    """Compute the edit (Levenshtein) distance between two strings.

    Insertions, deletions and substitutions each cost 1. Only two rows of the
    dynamic-programming table are kept at any time.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The minimum number of single-character edits turning one string into
        the other.
    """
    # Make sure s1 is the shorter string so each row is as short as possible.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    # Row for the empty prefix of s2: distance to each prefix of s1.
    previous_row: list[int] = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        # First cell: distance from s2[: i2 + 1] to the empty prefix of s1.
        current_row = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                # Matching characters carry the diagonal value over for free.
                current_row.append(previous_row[i1])
            else:
                # Cheapest of substitution (diagonal) or an insertion/deletion
                # (cell above / cell to the left), plus one edit.
                current_row.append(
                    1
                    + min(
                        previous_row[i1],
                        previous_row[i1 + 1],
                        current_row[-1],
                    )
                )
        previous_row = current_row
    return previous_row[-1]