# Copyright (c) 2022, Lawrence Livermore National Security, LLC. 
# All rights reserved.
# See the top-level LICENSE and NOTICE files for details.
# LLNL-CODE-838964

# SPDX-License-Identifier: Apache-2.0-with-LLVM-exception

from sentence_transformers.cross_encoder import CrossEncoder as CE
from typing import List, Dict, Tuple
import json
from collections import defaultdict


class CrossEncoder:
    def __init__(self,
                 model_path: str = None,
                 max_length: int = None,
                 **kwargs):

        # Only pass max_length through when it is set; otherwise fall back
        # to the underlying model's default.
        if max_length is not None:
            self.model = CE(model_path, max_length=max_length, **kwargs)
        else:
            self.model = CE(model_path, **kwargs)


    def predict(self,
                sentences: List[Tuple[str, str]],
                batch_size: int = 32,
                show_progress_bar: bool = False) -> List[float]:

        return self.model.predict(sentences=sentences,
                                  batch_size=batch_size,
                                  show_progress_bar=show_progress_bar)


class CERank:

    def __init__(self, model, batch_size: int = 128, **kwargs):
        self.cross_encoder = model
        self.batch_size = batch_size


    def flatten_examples(self, contexts: Dict[str, Dict], question: str):

        # Pair the single question with every candidate context; all pairs
        # share the placeholder query id 'question_0'.
        text_pairs, pair_ids = [], []
        for context_id, context in contexts.items():
            pair_ids.append(['question_0', context_id])
            text_pairs.append([question, context['text']])

        return text_pairs, pair_ids

    def group_questionrank(self, pair_ids, rank_scores):

        # Collect (paragraph_id, score) tuples per query id, still unsorted.
        unsorted = defaultdict(list)
        for pair, score in zip(pair_ids, rank_scores):
            query_id, paragraph_id = pair[0], pair[1]
            unsorted[query_id].append((paragraph_id, score))

        return unsorted

    def get_rankings(self, pair_ids, rank_scores, text_pairs):

        unsorted_ranks = self.group_questionrank(pair_ids, rank_scores)
        rankings = defaultdict(dict)

        # Sort each query's contexts by descending score. Indexing text_pairs
        # by the query's position recovers the question text, which works here
        # because every pair shares the single question.
        for idx, (query_id, ranks) in enumerate(unsorted_ranks.items()):
            sorted_pairs = sorted(ranks, key=lambda item: item[1], reverse=True)
            sorted_ids, scores = zip(*sorted_pairs)
            rankings[query_id]['text'] = text_pairs[idx][0]
            rankings[query_id]['scores'] = list(scores)
            rankings[query_id]['ranks'] = list(sorted_ids)

        return rankings
      

    def rank(self,
             contexts: Dict[str, Dict],
             question: str):

        # Score every (question, context) pair with the cross-encoder and
        # return the contexts ranked by relevance.
        text_pairs, pair_ids = self.flatten_examples(contexts, question)
        rank_scores = [float(score) for score in
                       self.cross_encoder.predict(text_pairs, batch_size=self.batch_size)]
        full_results = self.get_rankings(pair_ids, rank_scores, text_pairs)

        return full_results
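
# Illustrative shape of the dict returned by CERank.rank; the ids, question,
# and scores below are made-up values, not output from a real run:
#
#     {'question_0': {'text': 'What is ...?',
#                     'scores': [0.92, 0.31],
#                     'ranks': ['ctx_2', 'ctx_1']}}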



def get_ranked_contexts(context_json, question):

    model_path = 'ms-marco-electra-base'
    max_length = 512

    # Can't use fast tokenizers (use_fast) while gradio is running; they
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)

    with open(context_json, 'r') as fin:
        contexts = json.load(fin)

    rankings = ranker.rank(contexts, question)

    # Strip the '.json' extension when naming the output file.
    with open('ranked_{0}.json'.format(context_json[:-5]), 'w') as fout:
        json.dump(rankings, fout)
        
def get_ranked_contexts_in_memory(contexts, question):

    model_path = 'ms-marco-electra-base'
    max_length = 512

    # Can't use fast tokenizers (use_fast) while gradio is running; they
    # conflict with tokenizer multiprocessing/parallelism.
    cross_encoder = CrossEncoder(model_path, max_length, tokenizer_args={'use_fast': False})
    ranker = CERank(cross_encoder)

    rankings = ranker.rank(contexts, question)

    return rankings
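

# A minimal usage sketch, assuming the 'ms-marco-electra-base' checkpoint is
# available locally. The context ids and texts below are hypothetical, chosen
# only to show the expected input shape ({id: {'text': ...}}).
if __name__ == '__main__':
    demo_contexts = {
        'ctx_1': {'text': 'A cross-encoder scores each question/passage pair jointly.'},
        'ctx_2': {'text': 'Bi-encoders embed questions and passages separately.'},
    }
    demo_rankings = get_ranked_contexts_in_memory(
        demo_contexts, 'How does a cross-encoder score passages?')
    print(json.dumps(demo_rankings, indent=2))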