File size: 2,154 Bytes
17283b0
 
 
 
3373861
17283b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import textwrap
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import torch
#from tabulate import tabulate
import time

# Model identifiers handed to the sentence-transformers constructors below.
model_bi_encoder = "msmarco-distilbert-base-tas-b"
model_cross_encoder = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Both models are instantiated once, at import time, and shared by every call
# in this module.
bi_encoder = SentenceTransformer(model_bi_encoder)
# NOTE(review): presumably the maximum input length (in tokens) before the
# encoder truncates - confirm against the sentence-transformers docs.
bi_encoder.max_seq_length = 512

cross_encoder = CrossEncoder(model_cross_encoder)

# Number of bi-encoder candidates `search` requests from semantic search
# before re-scoring them with the cross-encoder.
top_k = 20


def get_corpus(passages):
    """Return bi-encoder embeddings for *passages*, cached in ``corpus.pt``.

    The embeddings are stored in a file named ``corpus.pt`` in the current
    working directory. An existing cache is reused only when it holds exactly
    one row per passage; otherwise the passages are re-encoded and the cache
    file is overwritten. Without that check, a cache written for a different
    passage list would be returned as-is and its row indices would no longer
    line up with *passages*.

    The row-count comparison is a cheap heuristic: a cache built from a
    different list of the *same* length is still accepted. Delete
    ``corpus.pt`` manually when passage contents change but the count does not.

    Args:
        passages: Sequence of passage strings to embed.

    Returns:
        The embeddings produced by ``bi_encoder.encode(...,
        convert_to_tensor=True)``, either freshly computed or loaded from the
        cache.
    """
    cache_path = "corpus.pt"

    if os.path.isfile(cache_path):
        corpus_embeddings = torch.load(cache_path)
        # Only trust the cache when it has one embedding row per passage.
        if len(corpus_embeddings) == len(passages):
            return corpus_embeddings

    # No cache yet, or the cache is stale: encode and (re)write it.
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
    torch.save(corpus_embeddings, cache_path)
    return corpus_embeddings


def search(query, passages):
    """Find the passages that best answer *query* using retrieve-and-rerank.

    The query is embedded with the bi-encoder and compared against the corpus
    embeddings to collect up to ``top_k`` candidate passages. Each candidate
    is then scored together with the query by the cross-encoder, and the
    candidates are re-ordered by that score, highest first.

    Args:
        query: The question text.
        passages: Sequence of passage strings; candidate hits index into it
            via their ``corpus_id``.

    Returns:
        A list with at most five rows, best match first. Each row is
        ``[answer, cross_score, bi_score]`` where ``answer`` is the passage
        with newlines replaced by spaces and re-wrapped to 50 columns, and
        the two scores are rendered as strings.
    """
    corpus_embeddings = get_corpus(passages)
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)

    # Stage 1: bi-encoder retrieval. A single query is passed in, so only the
    # first (and only) hit list is of interest.
    # NOTE(review): semantic_search runs with its default score function here;
    # confirm that default suits the configured bi-encoder (its model card may
    # recommend dot-product, i.e. score_function=util.dot_score).
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]

    # Stage 2: score every (query, candidate passage) pair with the
    # cross-encoder and attach the score to the corresponding hit.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, cross_score in zip(hits, cross_scores):
        hit['cross-score'] = cross_score

    # Sort results by the cross-encoder scores
    hits = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)

    # The wrapper configuration never changes, so build it once.
    wrapper = textwrap.TextWrapper(width=50)
    result_table = []
    for hit in hits[:5]:
        answer = wrapper.fill(text=passages[hit['corpus_id']].replace("\n", " "))
        # "{}".format(...) is kept deliberately: it goes through __format__,
        # which may render a NumPy scalar differently than str() would.
        cross_score_text = "{}".format(hit['cross-score'])
        bi_score_text = "{}".format(hit['score'])
        result_table.append([answer, cross_score_text, bi_score_text])

    return result_table