Spaces:
Runtime error
Runtime error
File size: 1,736 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import sys
import tqdm
import ujson
import random
from argparse import ArgumentParser
from collections import OrderedDict
from colbert.utils.utils import print_message, file_tqdm
def main(args):
    """Split a ranking file into one output file per query file.

    Each query ID read from ``args.all_queries[i]`` is mapped to index ``i``.
    Every line of ``args.ranking`` is then written unchanged to
    ``f'{args.ranking}.{i}'``, where ``i`` is the index mapped to the query
    ID in the line's first tab-separated column.

    Args:
        args: Object with two attributes: ``ranking`` (path of the ranking
            file) and ``all_queries`` (sequence of paths to tab-separated
            files whose first column is an integer query ID).

    Raises:
        AssertionError: If a query ID occurs in more than one query file, or
            if any of the output paths already exists.
        KeyError: If the ranking contains a query ID found in no query file.
        ValueError: If a first column cannot be parsed as an integer.
    """
    qid_to_file_idx = {}

    for qrels_idx, qrels in enumerate(args.all_queries):
        with open(qrels) as f:
            for line in f:
                if not line.strip():
                    # Blank lines (e.g. a trailing newline) carry no query ID.
                    continue

                qid, *_ = line.strip().split('\t')
                qid = int(qid)

                # Raised explicitly instead of via `assert` so the check is
                # not stripped when Python runs with -O.
                if qid_to_file_idx.setdefault(qid, qrels_idx) != qrels_idx:
                    raise AssertionError((qid, qrels_idx))

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    # Never overwrite earlier results; this guard must also hold under -O.
    existing_paths = [path for path in all_outputs_paths if os.path.exists(path)]
    if existing_paths:
        raise AssertionError(f"output file(s) already exist: {existing_paths}")

    all_outputs = []

    try:
        # Opened one at a time so that handles already opened are still
        # closed by the `finally` below if a later open() fails.
        for path in all_outputs_paths:
            all_outputs.append(open(path, 'w'))

        with open(args.ranking) as f:
            print_message(f"#> Loading ranked lists from {f.name} ..")

            last_file_idx = -1

            # file_tqdm presumably yields the file's lines while showing
            # progress -- confirm against colbert.utils.utils.
            for line in file_tqdm(f):
                if not line.strip():
                    continue

                qid, *_ = line.strip().split('\t')
                file_idx = qid_to_file_idx[int(qid)]

                # Report only when consecutive lines go to different outputs.
                if file_idx != last_file_idx:
                    print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")

                last_file_idx = file_idx

                all_outputs[file_idx].write(line)

        print()

        for output in all_outputs:
            print(output.name)
    finally:
        # Close every output even when splitting fails part-way through.
        for output in all_outputs:
            output.close()

    print("#> Done!")
if __name__ == "__main__":
    # Seed the global RNG with a fixed value before doing anything else.
    random.seed(12345)

    cli = ArgumentParser(description='.')

    # Input arguments: the ranking file to split, and one or more query
    # files that decide which output each ranking line goes to.
    cli.add_argument('--ranking', dest='ranking', type=str, required=True)
    cli.add_argument(
        '--all-queries',
        dest='all_queries',
        type=str,
        nargs='+',
        required=True,
    )

    main(cli.parse_args())
|