|
import argparse |
|
import json |
|
import os |
|
import time |
|
import numpy as np |
|
import nltk |
|
from rank_bm25 import BM25Okapi |
|
|
|
|
|
def combine_all_sentences(knowledge_file):
    """Flatten all sentences of a JSON-lines knowledge file into parallel lists.

    Every line of ``knowledge_file`` is parsed as a JSON object. The entries
    of its ``"url2text"`` list are appended to the sentence list, and its
    ``"url"`` value is repeated once per entry so both lists stay aligned.

    Args:
        knowledge_file: Path of a UTF-8 encoded file with one JSON object
            per line.

    Returns:
        A tuple ``(sentences, urls, num_records)`` where ``sentences[i]``
        was read from the record whose ``"url"`` is ``urls[i]`` and
        ``num_records`` is the number of lines read. An empty file yields
        ``([], [], 0)``.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON.
        KeyError: If a record lacks ``"url"`` or ``"url2text"``.
    """
    sentences, urls = [], []
    # Counted explicitly (instead of reusing the loop index after the loop)
    # so that an empty file returns 0 rather than hitting an unbound name.
    num_records = 0

    with open(knowledge_file, "r", encoding="utf-8") as json_file:
        for line in json_file:
            data = json.loads(line)
            texts = data["url2text"]
            sentences.extend(texts)
            # One URL entry per sentence keeps the two lists index-aligned.
            urls.extend([data["url"]] * len(texts))
            num_records += 1

    return sentences, urls, num_records
|
|
|
|
|
def retrieve_top_k_sentences(query, document, urls, top_k):
    """Rank sentences against a query with BM25 and keep the best ``top_k``.

    Args:
        query: Query text; tokenized with ``nltk.word_tokenize``.
        document: List of sentence strings forming the BM25 corpus.
        urls: List indexed in parallel with ``document`` (``urls[i]`` is
            returned alongside ``document[i]``).
        top_k: Maximum number of sentences to return.

    Returns:
        A tuple ``(sentences, urls)`` holding the selected sentences and
        their matching URL entries, ordered from highest to lowest BM25
        score. Both lists are empty when ``document`` is empty.
    """
    # An empty corpus has nothing to rank, and building a BM25 index over it
    # fails (rank_bm25 averages document length over the corpus size, which
    # is a division by zero), so short-circuit before touching the index.
    if len(document) == 0:
        return [], []

    tokenized_docs = [nltk.word_tokenize(doc) for doc in document]
    bm25 = BM25Okapi(tokenized_docs)
    scores = bm25.get_scores(nltk.word_tokenize(query))
    # argsort is ascending: reverse for best-first order, then truncate.
    top_k_idx = np.argsort(scores)[::-1][:top_k]

    return [document[i] for i in top_k_idx], [urls[i] for i in top_k_idx]
|
|
|
|
|
def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Get top 100 sentences with BM25 in the knowledge store."
    )
    parser.add_argument(
        "-k",
        "--knowledge_store_dir",
        type=str,
        default="data_store/output_dev",
        help="The path of the knowledge_store_dir containing json files with all the retrieved sentences.",
    )
    parser.add_argument(
        "-c",
        "--claim_file",
        type=str,
        default="data/dev.json",
        help="The path of the file that stores the claim.",
    )
    parser.add_argument(
        "-o",
        "--json_output",
        type=str,
        default="data_store/dev_top_k.json",
        help="The output dir for JSON files to save the top 100 sentences for each claim.",
    )
    parser.add_argument(
        "--top_k",
        default=100,
        type=int,
        help="How many documents should we pick out with BM25.",
    )
    parser.add_argument(
        "-s",
        "--start",
        type=int,
        default=0,
        help="Staring index of the files to process.",
    )
    parser.add_argument(
        "-e", "--end", type=int, default=-1, help="End index of the files to process."
    )
    return parser.parse_args()


def main():
    """Retrieve the top-k BM25 sentences for each selected claim.

    Loads the claims from ``--claim_file``. For every claim whose position
    lies in ``[--start, --end)`` it reads ``<knowledge_store_dir>/<idx>.json``,
    ranks its sentences against the claim text, and appends one JSON line
    with the result to ``--json_output``.
    """
    args = _parse_args()

    with open(args.claim_file, "r", encoding="utf-8") as json_file:
        target_examples = json.load(json_file)

    # -1 means "up to the end": use the number of entries in the store dir.
    if args.end == -1:
        args.end = len(os.listdir(args.knowledge_store_dir))
        print(args.end)

    # A range supports len() and constant-time integer membership tests, so
    # the per-claim `in` check no longer scans a materialised list.
    files_to_process = range(args.start, args.end)
    total = len(files_to_process)

    with open(args.json_output, "w", encoding="utf-8") as output_json:
        done = 0
        for idx, example in enumerate(target_examples):
            if idx not in files_to_process:
                continue

            # Load the knowledge store belonging to this claim.
            print(f"Processing claim {idx}... Progress: {done + 1} / {total}")
            document_in_sentences, sentence_urls, num_urls_this_claim = (
                combine_all_sentences(
                    os.path.join(args.knowledge_store_dir, f"{idx}.json")
                )
            )

            print(
                f"Obtained {len(document_in_sentences)} sentences from {num_urls_this_claim} urls."
            )

            # Rank the sentences against the claim and time the retrieval.
            st = time.time()
            top_k_sentences, top_k_urls = retrieve_top_k_sentences(
                example["claim"], document_in_sentences, sentence_urls, args.top_k
            )
            print(f"Top {args.top_k} retrieved. Time elapsed: {time.time() - st}.")

            json_data = {
                "claim_id": idx,
                "claim": example["claim"],
                f"top_{args.top_k}": [
                    {"sentence": sent, "url": url}
                    for sent, url in zip(top_k_sentences, top_k_urls)
                ],
            }
            output_json.write(json.dumps(json_data, ensure_ascii=False) + "\n")
            done += 1
            # Flush after every claim so partial results survive an
            # interrupted run.
            output_json.flush()


if __name__ == "__main__":
    main()
|
|