OpenFactCheck-Prerelease
/
src
/openfactcheck
/solvers
/rarr_solvers
/rarr_utils
/evidence_selection.py
import itertools
from typing import Any, Dict, List, Sequence

import torch
from sentence_transformers import CrossEncoder
# Shared MS MARCO MiniLM cross-encoder used to score (question, evidence) pairs.
# It is constructed once when this module is imported and is placed on the GPU
# whenever torch reports CUDA as available, otherwise on the CPU.
PASSAGE_RANKER = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    max_length=512,
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
)
def compute_score_matrix( | |
questions: List[str], evidences: List[str] | |
) -> List[List[float]]: | |
"""Scores the relevance of all evidence against all questions using a CrossEncoder. | |
Args: | |
questions: A list of unique questions. | |
evidences: A list of unique evidences. | |
Returns: | |
score_matrix: A 2D list list of question X evidence relevance scores. | |
""" | |
score_matrix = [] | |
for q in questions: | |
evidence_scores = PASSAGE_RANKER.predict([(q, e) for e in evidences]).tolist() | |
score_matrix.append(evidence_scores) | |
return score_matrix | |
def question_coverage_objective_fn(
    score_matrix: List[List[float]], evidence_indices: Sequence[int]
) -> float:
    """Given (query, evidence) scores and a subset of evidence, return the coverage.

    Given all pairwise query and evidence scores, and a subset of the evidence
    specified by indices, return a value indicating how well this subset of
    evidence covers (i.e., helps answer) all questions. The value is
    ``sum over questions q of max over selected evidence e of score(q, e)``,
    which rewards subsets where every question is explained by at least one
    selected evidence.

    Args:
        score_matrix: A 2D list of question x evidence relevance scores.
        evidence_indices: Column indices into ``score_matrix`` selecting the
            evidence subset to score. It is iterated once per question, so it
            must be a re-iterable sequence (e.g. the tuples produced by
            ``itertools.combinations``), not a one-shot iterator.

    Returns:
        total: The coverage obtained by using the evidence in
            ``evidence_indices`` over all questions; ``0.0`` when
            ``score_matrix`` has no questions.

    Raises:
        ValueError: If ``evidence_indices`` is empty while ``score_matrix``
            contains at least one question (the max over an empty subset is
            undefined).
    """
    if not evidence_indices and score_matrix:
        raise ValueError("evidence_indices must select at least one evidence")
    # Start from 0.0 so the result is always a float, even with no questions.
    return sum(
        (max(scores_for_question[j] for j in evidence_indices)
         for scores_for_question in score_matrix),
        0.0,
    )
def select_evidences(
    example: Dict[str, Any], max_selected: int = 5, prefer_fewer: bool = False
) -> List[Dict[str, Any]]:
    """Selects the set of evidence that maximizes information coverage over the claim.

    Every subset of the deduplicated evidence whose size lies in the allowed
    range is scored with ``question_coverage_objective_fn`` and the best one is
    kept. The search is exhaustive over ``itertools.combinations``, so its cost
    grows combinatorially with the number of evidences and ``max_selected``.

    Args:
        example: The result of running the editing pipeline on one claim. This
            function reads ``example["questions"]`` and the ``"text"`` field of
            each entry in ``example["revisions"][0]["evidences"]``.
        max_selected: Maximum number of evidences to select. Values ``<= 0``
            select nothing.
        prefer_fewer: If True and the maximum objective value can be achieved by
            fewer evidences than `max_selected`, prefer selecting fewer evidences.
            If False, only subsets of exactly ``min(max_selected, #evidences)``
            are considered.

    Returns:
        selected_evidences: Selected evidences that serve as the attribution
            report, as ``{"text": ...}`` dicts in sorted text order. Empty when
            there is no evidence or ``max_selected <= 0``.
    """
    # Deduplicate and sort so scoring and tie-breaking are deterministic.
    questions = sorted(set(example["questions"]))
    # Only the first revision's evidence pool is considered.
    evidences = sorted({e["text"] for e in example["revisions"][0]["evidences"]})
    num_evidences = len(evidences)
    # A non-positive budget would otherwise search size-0 subsets, whose
    # coverage is undefined whenever there is at least one question.
    if not num_evidences or max_selected <= 0:
        return []

    score_matrix = compute_score_matrix(questions, evidences)

    max_selected = min(max_selected, num_evidences)
    min_selected = 1 if prefer_fewer else max_selected

    best_combo = ()
    best_objective_value = float("-inf")
    for num_selected in range(min_selected, max_selected + 1):
        for combo in itertools.combinations(range(num_evidences), num_selected):
            objective_value = question_coverage_objective_fn(score_matrix, combo)
            # Strict ">" keeps the first subset among ties; since sizes are
            # visited in increasing order, that is the smallest one.
            if objective_value > best_objective_value:
                best_combo = combo
                best_objective_value = objective_value

    return [{"text": evidences[idx]} for idx in best_combo]