Spaces:
Runtime error
Runtime error
File size: 4,206 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""
Evaluate MS MARCO Passages ranking.
"""
import os
import math
import tqdm
import ujson
import random
from argparse import ArgumentParser
from collections import defaultdict
from colbert.utils.utils import print_message, file_tqdm
def _load_qrels(path):
    """Read a qrels file into ``{qid: [pid, ...]}`` of judged-positive passages.

    Each non-blank line must hold four whitespace-separated integers
    ``qid _ pid label``; only ``label == 1`` is accepted.
    """
    qid2positives = defaultdict(list)

    with open(path, encoding='utf-8') as f:
        print_message(f"#> Loading QRELs from {path} ..")
        for line in file_tqdm(f):
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            qid, _, pid, label = map(int, line.strip().split())
            assert label == 1
            qid2positives[qid].append(pid)

    return qid2positives


def _load_ranking(path):
    """Read a ranking TSV into ``{qid: [(rank, pid, score), ...]}`` in file order.

    Each non-blank line is ``qid<TAB>pid<TAB>rank`` with an optional fourth
    ``score`` column; ``score`` is ``None`` when that column is absent.
    """
    qid2ranking = defaultdict(list)

    with open(path, encoding='utf-8') as f:
        print_message(f"#> Loading ranked lists from {path} ..")
        for line in file_tqdm(f):
            if not line.strip():
                continue  # tolerate blank (e.g. trailing) lines
            qid, pid, rank, *score = line.strip().split('\t')
            qid, pid, rank = int(qid), int(pid), int(rank)

            if score:
                assert len(score) == 1
                score = float(score[0])
            else:
                score = None

            qid2ranking[qid].append((rank, pid, score))

    return qid2ranking


def _positive_positions(ranking, positives):
    """Return the 1-based positions where each distinct positive first appears.

    Positions follow the order of *ranking* (i.e. file order), not its rank
    column. A pid listed more than once contributes only its first position.
    """
    seen = set()
    positions = []
    for position, (_, pid, _) in enumerate(ranking, start=1):
        if pid in positives and pid not in seen:
            seen.add(pid)
            positions.append(position)
    return positions


def _safe_mean(total, count):
    """Return ``total / count``, or NaN when *count* is zero (no queries)."""
    return total / count if count else float('nan')


def _write_annotations(path, qid2positives, qid2ranking):
    """Write ``qid pid position [score] label`` TSV lines for every ranked pair.

    ``label`` is 1 when the pid is a judged positive for the qid, else 0; the
    score column is omitted when the ranking had none.
    """
    print_message(f"#> Writing annotations to {path} ..")

    with open(path, 'w', encoding='utf-8') as f:
        for qid in tqdm.tqdm(qid2positives):
            positives = qid2positives[qid]
            for position, (_, pid, score) in enumerate(qid2ranking.get(qid, []), start=1):
                fields = [qid, pid, position, score, int(pid in positives)]
                f.write('\t'.join(str(x) for x in fields if x is not None) + '\n')


def main(args):
    """Print MRR@10 and Recall@k for a ranking, optionally writing annotations.

    Reads ``args.qrels`` and ``args.ranking`` (file paths) and ``args.annotate``
    (bool); when ``annotate`` is set, ``args.output`` must name the file to write.

    NOTE: a pid's rank is its 1-based position among its query's lines in the
    ranking file; the rank column itself is parsed but not used for scoring.
    Every metric is printed twice: averaged over all judged queries and over
    ranked queries only. An empty denominator prints ``nan`` instead of raising.
    """
    qid2positives = _load_qrels(args.qrels)
    qid2ranking = _load_ranking(args.ranking)

    assert set(qid2ranking) <= set(qid2positives), "ranking contains unjudged qids"

    num_judged_queries = len(qid2positives)
    num_ranked_queries = len(qid2ranking)

    if num_judged_queries != num_ranked_queries:
        print()
        print_message("#> [WARNING] num_judged_queries != num_ranked_queries")
        print_message(f"#> {num_judged_queries} != {num_ranked_queries}")
        print()

    print_message(f"#> Computing MRR@10 for {num_judged_queries} queries.")

    # Sets give O(1) membership and make recall count each distinct positive once.
    qid2positive_set = {qid: set(pids) for qid, pids in qid2positives.items()}

    qid2mrr = {}
    qid2recall = {depth: {} for depth in [50, 200, 1000, 5000, 10000]}

    for qid, positives in tqdm.tqdm(qid2positive_set.items(), total=num_judged_queries):
        # .get avoids inserting empty entries for judged-but-unranked queries.
        hit_positions = _positive_positions(qid2ranking.get(qid, []), positives)

        # MRR@10 only credits the first positive, and only within the top 10.
        if hit_positions and hit_positions[0] <= 10:
            qid2mrr[qid] = 1.0 / hit_positions[0]

        for depth, qid2score in qid2recall.items():
            hits = sum(1 for position in hit_positions if position <= depth)
            if hits:
                qid2score[qid] = hits / len(positives)

    assert len(qid2mrr) <= num_ranked_queries, (len(qid2mrr), num_ranked_queries)

    print()
    mrr_10_sum = sum(qid2mrr.values())
    print_message(f"#> MRR@10 = {_safe_mean(mrr_10_sum, num_judged_queries)}")
    print_message(f"#> MRR@10 (only for ranked queries) = {_safe_mean(mrr_10_sum, num_ranked_queries)}")
    print()

    for depth, qid2score in qid2recall.items():
        assert len(qid2score) <= num_ranked_queries, (len(qid2score), num_ranked_queries)
        print()
        metric_sum = sum(qid2score.values())
        print_message(f"#> Recall@{depth} = {_safe_mean(metric_sum, num_judged_queries)}")
        print_message(f"#> Recall@{depth} (only for ranked queries) = {_safe_mean(metric_sum, num_ranked_queries)}")
        print()

    if args.annotate:
        _write_annotations(args.output, qid2positive_set, qid2ranking)
if __name__ == "__main__":
    parser = ArgumentParser(description="msmarco_passages.")

    # Input Arguments.
    parser.add_argument('--qrels', dest='qrels', required=True, type=str,
                        help="qrels file: 'qid _ pid label' per line (label must be 1).")
    parser.add_argument('--ranking', dest='ranking', required=True, type=str,
                        help="ranking TSV: qid, pid, rank and an optional score per line.")
    parser.add_argument('--annotate', dest='annotate', default=False, action='store_true',
                        help="also write <ranking>.annotated with a 0/1 label per ranked pair.")

    args = parser.parse_args()

    if args.annotate:
        args.output = f'{args.ranking}.annotated'
        # Refuse to clobber an earlier annotation file. parser.error (not assert)
        # keeps this check active under `python -O` and exits with a usage message.
        if os.path.exists(args.output):
            parser.error(f"annotation output already exists: {args.output}")

    main(args)
|