File size: 3,469 Bytes
8360ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import itertools
from typing import Any, Dict, List

import torch
from sentence_transformers import CrossEncoder

# Cross-encoder used to score (question, evidence) relevance pairs. It is
# constructed once at import time, on the GPU when torch reports one available
# and on the CPU otherwise.
PASSAGE_RANKER = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2",
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
    max_length=512,
)


def compute_score_matrix(
    questions: List[str], evidences: List[str], ranker: Any = None
) -> List[List[float]]:
    """Scores the relevance of all evidence against all questions using a CrossEncoder.

    Every (question, evidence) pair is scored in a single ``predict`` call so the
    ranker can batch across questions, rather than being invoked once per question.

    Args:
        questions: A list of unique questions.
        evidences: A list of unique evidences.
        ranker: Optional object with a CrossEncoder-style ``predict(pairs)`` method
            whose result supports ``.tolist()``. Defaults to the module-level
            ``PASSAGE_RANKER``.
    Returns:
        score_matrix: A 2D list of question x evidence relevance scores, where
            ``score_matrix[i][j]`` scores ``evidences[j]`` for ``questions[i]``.
            When ``evidences`` is empty, each question gets an empty row.
    """
    num_evidences = len(evidences)
    if not questions or not num_evidences:
        # Nothing to score. CrossEncoder.predict inspects the first element of its
        # input, so it is not safe to call it with an empty batch.
        return [[] for _ in questions]

    if ranker is None:
        ranker = PASSAGE_RANKER

    # Question-major order: all evidences for questions[0], then questions[1], ...
    pairs = [(q, e) for q in questions for e in evidences]
    flat_scores = ranker.predict(pairs).tolist()
    # Each consecutive run of `num_evidences` scores belongs to one question.
    return [
        flat_scores[row * num_evidences : (row + 1) * num_evidences]
        for row in range(len(questions))
    ]


def question_coverage_objective_fn(
    score_matrix: List[List[float]], evidence_indices: List[int]
) -> float:
    """Measure how well a subset of evidence covers every question.

    The coverage is sum over questions q of max over selected evidences e of
    score(q, e). Because each question contributes only its best selected score,
    subsets in which every question is explained by at least one evidence are
    rewarded.

    Args:
        score_matrix: A 2D list of question x evidence relevance scores; row i
            holds the scores of every evidence for question i.
        evidence_indices: Column indices of the evidence subset to evaluate. It is
            iterated once per question, so pass a re-iterable sequence.
    Returns:
        The summed per-question best scores (0.0 when there are no questions).
    Raises:
        ValueError: If `evidence_indices` is empty and there is at least one question.
    """
    coverage = 0.0
    for question_scores in score_matrix:
        # The best-matching selected evidence decides this question's coverage.
        selected_scores = [question_scores[idx] for idx in evidence_indices]
        coverage += max(selected_scores)
    return coverage


def select_evidences(
    example: Dict[str, Any], max_selected: int = 5, prefer_fewer: bool = False
) -> List[Dict[str, Any]]:
    """Selects the set of evidence that maximizes information coverage over the claim.

    Candidate subsets of the unique evidences are scored with
    `question_coverage_objective_fn`, and the best-scoring subset is returned. The
    search enumerates `itertools.combinations` exhaustively, so its cost grows
    combinatorially with the number of unique evidences.

    Args:
        example: The result of running the editing pipeline on one claim. Only
            `example["questions"]` (question strings) and
            `example["revisions"][0]["evidences"]` (dicts with a "text" key) are
            read. Both are de-duplicated and sorted.
        max_selected: Maximum number of evidences to select. Values below 1
            select nothing.
        prefer_fewer: If True and the maximum objective value can be achieved by
            fewer evidences than `max_selected`, prefer selecting fewer evidences.
            Otherwise exactly `min(max_selected, number of unique evidences)`
            evidences are selected.
    Returns:
        selected_evidences: Selected evidences that serve as the attribution
            report, each of the form `{"text": ...}`, in sorted text order.
    """
    questions = sorted(set(example["questions"]))
    evidences = sorted({e["text"] for e in example["revisions"][0]["evidences"]})
    num_evidences = len(evidences)
    # Return early if there is nothing to choose from or nothing may be chosen.
    # Without the `max_selected` check, 0 would score the empty subset (which
    # question_coverage_objective_fn cannot do when questions exist), and a
    # negative value would reach itertools.combinations with r < 0.
    if not num_evidences or max_selected < 1:
        return []

    score_matrix = compute_score_matrix(questions, evidences)

    best_combo = ()
    best_objective_value = float("-inf")
    max_selected = min(max_selected, num_evidences)
    # Without prefer_fewer, only full-size subsets are considered. With it, sizes
    # are scanned from smallest to largest, and a larger subset replaces the
    # incumbent only when it is strictly better, so ties favor fewer evidences.
    min_selected = 1 if prefer_fewer else max_selected
    for num_selected in range(min_selected, max_selected + 1):
        for combo in itertools.combinations(range(num_evidences), num_selected):
            objective_value = question_coverage_objective_fn(score_matrix, combo)
            if objective_value > best_objective_value:
                best_combo = combo
                best_objective_value = objective_value

    return [{"text": evidences[idx]} for idx in best_combo]