# DocumentQA/CrossEncoder/cross_encoder.py
# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Dict, Tuple
import json
from collections import defaultdict
import os

class CrossEncoder:
    def __init__(self,
                 model_path: str = None,
                 max_length: int = None,
                 **kwargs):
        if max_length is not None:
            self.model = CE(model_path, max_length=max_length, **kwargs)
        else:
            self.model = CE(model_path, **kwargs)

    def predict(self,
                sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = False) -> List[float]:
        return self.model.predict(sentences=sentences,
                                  batch_size=batch_size,
                                  show_progress_bar=show_progress_bar)

class CERank:
    def __init__(self, model, batch_size: int = 128, **kwargs):
        self.cross_encoder = model
        self.batch_size = batch_size

    def flatten_examples(self, contexts: Dict[str, Dict], question: str):
        # Build parallel lists of (question, context text) pairs and their
        # (query_id, context_id) identifiers; a single synthetic query id is
        # used because rank() handles one question at a time.
        text_pairs, pair_ids = [], []
        for context_id, context in contexts.items():
            pair_ids.append(['question_0', context_id])
            text_pairs.append([question, context['text']])
        return text_pairs, pair_ids

    def group_questionrank(self, pair_ids, rank_scores):
        # Group (context_id, score) tuples by their query id.
        unsorted = defaultdict(list)
        for pair, score in zip(pair_ids, rank_scores):
            query_id, paragraph_id = pair[0], pair[1]
            unsorted[query_id].append((paragraph_id, score))
        return unsorted

    def get_rankings(self, pair_ids, rank_scores, text_pairs):
        unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
        rankings = defaultdict(dict)
        for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
            # Sort contexts by cross-encoder score, highest first.
            sort_ranks = sorted(ranks, key=lambda item: item[1], reverse=True)
            sorted_ranks, scores = list(zip(*sort_ranks))
            rankings[query_id]['text'] = text_pairs[idx][0]
            rankings[query_id]['scores'] = list(scores)
            rankings[query_id]['ranks'] = list(sorted_ranks)
        return rankings

    def rank(self,
             contexts: Dict[str, Dict],
             question: str):
        text_pairs, pair_ids = self.flatten_examples(contexts, question)
        rank_scores = [float(score) for score in self.cross_encoder.predict(text_pairs, batch_size=self.batch_size)]
        full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)
        return full_results
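
    # Illustrative note (not from the original code): for a single question and
    # hypothetical contexts 'c0', 'c1', 'c2', rank() returns a dict shaped like
    #
    #   {'question_0': {'text':   'Where is the Eiffel Tower?',
    #                   'scores': [9.1, 2.3, -4.7],      # descending cross-encoder scores
    #                   'ranks':  ['c2', 'c0', 'c1']}}   # context ids, best match first
    #
    # The ids and scores above are made-up values for illustration only.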

def get_ranked_contexts(context_json, question):
    dirname = 'examples'
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Fast tokenizers (use_fast=True) can't be used while gradio is running;
    # they conflict with the tokenizer's multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    with open(context_json, 'r') as fin:
        contexts = json.load(fin)
    rankings = ranker.rank(contexts, question)
    # Writes e.g. 'ranked_contexts.json' for an input named 'contexts.json'.
    with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
        json.dump(rankings, fout)

def get_ranked_contexts_in_memory(contexts, question):
    dirname = 'examples'
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Fast tokenizers (use_fast=True) can't be used while gradio is running;
    # they conflict with the tokenizer's multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    rankings = ranker.rank(contexts, question)
    return rankings
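
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): run this file
# directly to rank a couple of toy contexts against a question. It assumes
# the 'ms-marco-electra-base' path used above points at a local cross-encoder
# checkpoint; the question and contexts are made-up examples.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_contexts = {
        'context_0': {'text': 'The Eiffel Tower is located in Paris, France.'},
        'context_1': {'text': 'Mount Everest is the highest mountain on Earth.'},
    }
    example_question = 'Where is the Eiffel Tower?'
    ranked = get_ranked_contexts_in_memory(example_contexts, example_question)
    # Prints the question text, context ids ordered best-first, and their scores.
    print(json.dumps(ranked, indent=2))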