ir_chinese_medqa / utility /rankings /split_by_queries.py
欧卫
'add_app_files'
58627fa
raw
history blame
1.74 kB
import os
import sys
import tqdm
import ujson
import random
from argparse import ArgumentParser
from collections import OrderedDict
from colbert.utils.utils import print_message, file_tqdm
def main(args):
    """Split one ranking file into a separate file per query set.

    Each path in ``args.all_queries`` is read as a tab-separated file whose
    first column is an integer query id (qid). Every line of ``args.ranking``
    (also tab-separated, qid in the first column) is then copied verbatim to
    ``f"{args.ranking}.{idx}"``, where ``idx`` is the position in
    ``args.all_queries`` of the query file that contains that line's qid.

    Args:
        args: Object with two attributes:
            ``ranking`` -- path of the ranking file to split;
            ``all_queries`` -- sequence of query-file paths.

    Raises:
        AssertionError: If a qid occurs in more than one query file, or if
            any of the output paths already exists.
        KeyError: If the ranking file mentions a qid that is absent from
            every query file.
        ValueError: If a first column cannot be parsed as an integer.
    """
    # Map every qid to the index of the query file it was found in.
    qid_to_file_idx = {}

    for qrels_idx, qrels in enumerate(args.all_queries):
        with open(qrels) as f:
            for line in f:
                qid, *_ = line.strip().split('\t')
                qid = int(qid)

                # A qid may repeat inside one file, but it must never be
                # claimed by two different query files.
                assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx)
                qid_to_file_idx[qid] = qrels_idx

    all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]

    # Refuse to clobber the results of a previous run.
    assert all(not os.path.exists(path) for path in all_outputs_paths)

    # Open the outputs one at a time inside the try block so that every
    # handle opened so far is closed even if a later open() or the split
    # itself raises.
    all_outputs = []

    try:
        for path in all_outputs_paths:
            all_outputs.append(open(path, 'w'))

        with open(args.ranking) as f:
            print_message(f"#> Loading ranked lists from {f.name} ..")

            last_file_idx = -1

            # file_tqdm comes from colbert.utils; presumably it yields the
            # file's lines while reporting progress -- confirm against colbert.
            for line in file_tqdm(f):
                qid, *_ = line.strip().split('\t')

                file_idx = qid_to_file_idx[int(qid)]

                # Report whenever consecutive lines go to different outputs.
                if file_idx != last_file_idx:
                    print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")

                last_file_idx = file_idx

                # Write the untouched line, including its original newline.
                all_outputs[file_idx].write(line)

        print()

        for output in all_outputs:
            print(output.name)
            output.close()
    finally:
        # Safety net for the error paths; closing an already-closed file
        # object is a no-op, so the success path is unaffected.
        for output in all_outputs:
            output.close()

    print("#> Done!")
if __name__ == "__main__":
    # Script entry point: seed the RNG, read the command line, run the split.
    random.seed(12345)

    cli = ArgumentParser(description='.')

    # Input Arguments
    # argparse derives the attribute names `ranking` and `all_queries` from
    # the option strings, so no explicit dest is needed.
    cli.add_argument('--ranking', type=str, required=True)
    cli.add_argument('--all-queries', type=str, nargs='+', required=True)

    main(cli.parse_args())