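"""Retrieval-augmented question answering over a chunked document corpus.

Combines dense similarity search (optionally with HyDE), BM25 keyword search,
and cross-encoder reranking; answers are generated in German by an LLM.
"""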
import numpy as np
import pandas as pd
from rank_bm25 import BM25Okapi

from llm import OpenAILLM

def cosine_similarity(vector1, vector2):
    return np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))

class QuestionAnswerer:
    def __init__(self, docs, embedding_model, llm=None, cross_encoder=None):
        self.docs = docs
        # The first chunk cannot be split into tokens, so it is excluded from the
        # BM25 index; sim_search_bm25 pads its score to stay aligned with self.docs.
        self.bm25 = BM25Okapi([c.split(" ") for c in self.docs.chunks.values[1:]])
        self.embedding_model = embedding_model
        # Instantiate the default LLM lazily instead of at import time.
        self.llm = llm if llm is not None else OpenAILLM('gpt-3.5-turbo')
        self.cross_encoder = cross_encoder

    def sim_search(self, query, n=10, use_hyde=False, use_dot_product=False):
        if use_hyde:
            # HyDE: embed a hypothetical generated answer instead of the raw query.
            generated_doc = self._get_generated_doc(query)
            print("generated document (hyde): \n", generated_doc)
            embedding = self.embedding_model.encode(generated_doc)
        else:
            embedding = self.embedding_model.encode(query)

        if use_dot_product:
            similarities = self.docs['embeddings'].apply(lambda x: np.dot(x, embedding))
        else:
            similarities = self.docs['embeddings'].apply(lambda x: cosine_similarity(x, embedding))

        self.docs['similarities'] = similarities
        return self.docs.sort_values('similarities', ascending=False).head(n)
    
    def sim_search_rerank(self, query, n=10, sim_search_n=100, **kwargs):
        # Retrieve a larger candidate set first, then rerank it with the cross-encoder.
        search_results = self.sim_search(query, n=sim_search_n, use_hyde=False, **kwargs)
        reranked_results = self.rerank(search_results, query)
        return reranked_results[:n]

    def sim_search_bm25(self, query, n=10):
        tokenized_query = query.split(" ")
        doc_scores = self.bm25.get_scores(tokenized_query)
        # The first chunk is not in the BM25 index (it cannot be split),
        # so pad its score with 0 to keep the scores aligned with self.docs.
        self.docs['bm25'] = np.insert(doc_scores, 0, 0)
        return self.docs.sort_values('bm25', ascending=False).head(n)

    def _create_prompt(self, context, question):
        return f"""
        Context information is below.
        ---------------------
        {context}
        ---------------------
        Given the context information and not prior knowledge, answer the query.
        Query: {question}
        Answer: \
        """

    def _get_generated_doc(self, question):
        prompt = f"""Write a guideline section in German answering the question below
        ---------------------
        Question: {question}
        ---------------------
        Answer: \
        """
        system_prompt = "You are an experienced radiologist answering medical questions. You answer only in German."
        return self.llm.get_response(system_prompt, prompt)
    
    
    def rerank(self, docs, query):
        if self.cross_encoder is None:
            raise ValueError("rerank requires a cross_encoder")
        # Score each (query, chunk) pair with the cross-encoder and sort by score.
        inp = [[query, doc.chunks] for doc in docs.itertuples()]
        docs['cross_score'] = self.cross_encoder.predict(inp)
        return docs.sort_values('cross_score', ascending=False)
    
    def answer_question(self, question, n=3, use_hyde=False, use_reranker=False, use_bm25=False):
        if use_reranker and use_hyde:
            print('reranking together with hyde is not supported yet')
        # Use elif so the reranked or BM25 results are not overwritten below.
        if use_reranker:
            search_results = self.sim_search_rerank(question, n=n)
        elif use_bm25:
            search_results = self.sim_search_bm25(question, n=n)
        else:
            search_results = self.sim_search(question, n=n, use_hyde=use_hyde)

        context = "\n\n".join(search_results['chunks'])

        prompt = self._create_prompt(context, question)

        system_prompt = "You are a helpful assistant answering questions in German. You answer only in German. If you do not know an answer you say it. You do not fabricate answers."

        return self.llm.get_response(system_prompt, prompt, temperature=0)
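

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example assuming sentence-transformers models for embedding and
# reranking. The model names, the sample chunks, and the docs DataFrame layout
# ('chunks' as str, 'embeddings' as np.ndarray per row) are assumptions inferred
# from how QuestionAnswerer reads them; the default OpenAILLM additionally
# requires the local llm module and a configured OpenAI API key.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer, CrossEncoder

    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    # Placeholder corpus; the first chunk is excluded from the BM25 index.
    chunks = [
        "Platzhalter für den ersten Chunk.",
        "Die Leitlinie empfiehlt eine MRT bei unklaren Befunden.",
        "Eine CT ist bei akuten Blutungen die Methode der Wahl.",
    ]
    docs = pd.DataFrame({
        "chunks": chunks,
        "embeddings": list(embedding_model.encode(chunks)),
    })

    qa = QuestionAnswerer(docs, embedding_model, cross_encoder=cross_encoder)
    print(qa.answer_question("Wann ist eine MRT indiziert?", n=2, use_reranker=True))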