# Copyright (c) 2022, Lawrence Livermore National Security, LLC.
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964
# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception
from sentence_transformers.cross_encoder import CrossEncoder as CE
import numpy as np
from typing import List, Dict, Tuple
import json
from collections import defaultdict
import os


class CrossEncoder:
    """Thin wrapper around sentence_transformers' CrossEncoder model."""

    def __init__(self, model_path: str = None, max_length: int = None, **kwargs):
        """Load a cross-encoder model.

        Args:
            model_path: Path or name of the pretrained cross-encoder model.
            max_length: Optional maximum input sequence length for the model.
            **kwargs: Extra keyword arguments forwarded to the underlying
                sentence_transformers CrossEncoder constructor.
        """
        # BUG FIX: the original built the max_length-configured model and then
        # unconditionally overwrote it with a default one, silently discarding
        # max_length (and loading the model twice). The missing `else` restores
        # the intended behavior.
        if max_length is not None:
            self.model = CE(model_path, max_length=max_length, **kwargs)
        else:
            self.model = CE(model_path, **kwargs)

    def predict(self, sentences: List[Tuple[str, str]], batch_size: int = 32,
                show_progress_bar: bool = False) -> List[float]:
        """Score (question, passage) text pairs with the cross-encoder.

        Args:
            sentences: List of (question, passage) string pairs to score.
            batch_size: Number of pairs scored per forward pass.
            show_progress_bar: Whether the underlying model shows progress.

        Returns:
            One relevance score per input pair.
        """
        return self.model.predict(
            sentences=sentences,
            batch_size=batch_size,
            show_progress_bar=show_progress_bar,
        )


class CERank:
    """Rank candidate contexts for a question using a cross-encoder."""

    def __init__(self, model, batch_size: int = 128, **kwargs):
        """Store the scoring model and batch size used during ranking.

        Args:
            model: Object exposing ``predict(pairs, batch_size=...)`` that
                returns one score per (question, context) pair.
            batch_size: Batch size passed through to the model's predict.
        """
        self.cross_encoder = model
        self.batch_size = batch_size

    def flatten_examples(self, contexts: Dict[str, Dict], question: str):
        """Expand a contexts dict into parallel lists of text pairs and ids.

        Args:
            contexts: Mapping of context_id -> dict with at least a 'text' key.
            question: The question string paired against every context.

        Returns:
            Tuple of (text_pairs, pair_ids) where text_pairs[i] is
            [question, context_text] and pair_ids[i] is
            ['question_0', context_id]. A single synthetic query id
            ('question_0') is used because only one question is ranked.
        """
        text_pairs, pair_ids = [], []
        for context_id, context in contexts.items():
            pair_ids.append(['question_0', context_id])
            text_pairs.append([question, context['text']])
        return text_pairs, pair_ids

    def group_questionrank(self, pair_ids, rank_scores):
        """Group (paragraph_id, score) tuples by their query id.

        Args:
            pair_ids: Iterable of [query_id, paragraph_id] pairs.
            rank_scores: Scores aligned index-for-index with pair_ids.

        Returns:
            defaultdict mapping query_id -> list of (paragraph_id, score),
            in input order (unsorted).
        """
        unsorted = defaultdict(list)
        for pair, score in zip(pair_ids, rank_scores):
            query_id, paragraph_id = pair[0], pair[1]
            unsorted[query_id].append((paragraph_id, score))
        return unsorted

    def get_rankings(self, pair_ids, rank_scores, text_pairs):
        """Produce per-query rankings sorted by descending score.

        Args:
            pair_ids: Iterable of [query_id, paragraph_id] pairs.
            rank_scores: Scores aligned index-for-index with pair_ids.
            text_pairs: The [question, context_text] pairs that were scored.

        Returns:
            Dict mapping query_id -> {'text': question, 'scores': [...],
            'ranks': [paragraph_id, ...]} with scores/ranks sorted from
            highest to lowest score.
        """
        unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
        rankings = defaultdict(dict)
        for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
            # Sort this query's (paragraph_id, score) tuples by score, best first.
            sort_ranks = sorted(ranks, key=lambda item: item[1], reverse=True)
            sorted_ranks, scores = list(zip(*sort_ranks))
            # NOTE(review): text_pairs[idx][0] recovers the question text by
            # indexing pairs with the query's enumeration position. This is
            # only guaranteed correct for the single-query case produced by
            # flatten_examples ('question_0'); verify before supporting
            # multiple queries.
            rankings[query_id]['text'] = text_pairs[idx][0]
            rankings[query_id]['scores'] = list(scores)
            rankings[query_id]['ranks'] = list(sorted_ranks)
        return rankings

    def rank(self, contexts: Dict[str, Dict], question: str):
        """Score and rank every context against the question.

        Args:
            contexts: Mapping of context_id -> dict with a 'text' key.
            question: The question to rank contexts for.

        Returns:
            Rankings dict as produced by :meth:`get_rankings`.
        """
        text_pairs, pair_ids = self.flatten_examples(contexts, question)
        rank_scores = [
            float(score)
            for score in self.cross_encoder.predict(
                text_pairs, batch_size=self.batch_size)
        ]
        full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)
        return full_results


def get_ranked_contexts(context_json, question):
    """Rank contexts read from a JSON file and write results to disk.

    Reads ``context_json`` (a JSON file mapping context_id -> {'text': ...}),
    ranks each context against ``question``, and writes the rankings to
    ``ranked_<basename>.json`` in the current directory.

    Args:
        context_json: Path to the input contexts JSON file; the output
            filename assumes it ends with '.json'.
        question: Question string to rank contexts against.
    """
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Can't use use_fast (fast tokenizers) while gradio is running, causes
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length,
                                 tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    with open(context_json, 'r') as fin:
        contexts = json.load(fin)
    rankings = ranker.rank(contexts, question)
    with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
        json.dump(rankings, fout)


def get_ranked_contexts_in_memory(contexts, question):
    """Rank already-loaded contexts against a question and return rankings.

    Args:
        contexts: Mapping of context_id -> dict with a 'text' key.
        question: Question string to rank contexts against.

    Returns:
        Rankings dict as produced by :meth:`CERank.get_rankings`.
    """
    model_path = 'ms-marco-electra-base'
    max_length = 512
    # Can't use use_fast (fast tokenizers) while gradio is running, causes
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length,
                                 tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)
    rankings = ranker.rank(contexts, question)
    return rankings