File size: 4,886 Bytes
7f46a81
 
a05cb39
 
7f46a81
eaf872d
 
7f46a81
eaf872d
f26592e
7f46a81
eaf872d
8ca1a6a
eaf872d
 
 
7f46a81
 
a05cb39
eaf872d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd0ad1b
 
 
 
 
 
 
 
 
 
 
eaf872d
 
 
 
fd0ad1b
 
8ca1a6a
eaf872d
 
fd0ad1b
 
8ca1a6a
fd0ad1b
eaf872d
 
 
 
 
 
7f46a81
ea58cde
a05cb39
 
 
 
 
eaf872d
 
 
 
 
 
 
 
a05cb39
 
 
 
8ca1a6a
a05cb39
eaf872d
 
 
 
 
8ca1a6a
eaf872d
a05cb39
7f46a81
 
eaf872d
 
7f46a81
 
 
7ff5239
eaf872d
 
39e2176
eaf872d
39e2176
a05cb39
7f46a81
8ca1a6a
7f46a81
eaf872d
 
 
 
 
8ca1a6a
eaf872d
 
f26592e
a05cb39
 
eaf872d
 
 
a05cb39
 
 
eaf872d
a05cb39
eaf872d
 
 
 
 
 
 
fd0ad1b
 
eaf872d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import requests
import json


class VectaraQuery:
    """Small client for the Vectara v2 chat endpoints that tracks one conversation.

    While ``conv_id`` is unset, queries are POSTed to ``CHATS_URL``; the
    ``chat_id`` found in the reply is stored in ``conv_id`` and subsequent
    queries are POSTed to ``{CHATS_URL}/{conv_id}/turns``.

    Attributes:
        corpus_keys: Corpus keys included in every search request.
        api_key: Value sent in the ``x-api-key`` header.
        prompt_name: Value sent as ``generation_preset_name``.
        conv_id: Chat id of the current conversation, or ``None`` before the
            first successful query.
    """

    # Base URL shared by the blocking and the streaming query methods.
    CHATS_URL = "https://api.vectara.io/v2/chats"
    # Preset used when the caller does not supply ``prompt_name``.
    DEFAULT_PROMPT_NAME = "vectara-summary-ext-24-05-sml"

    def __init__(self, api_key: str, corpus_keys: list[str], prompt_name: str = None):
        """Store credentials and search configuration; no network I/O happens here.

        Args:
            api_key: API key sent with every request.
            corpus_keys: Keys of the corpora to search.
            prompt_name: Generation preset name; a falsy value selects
                ``DEFAULT_PROMPT_NAME``.
        """
        self.corpus_keys = corpus_keys
        self.api_key = api_key
        self.prompt_name = prompt_name if prompt_name else self.DEFAULT_PROMPT_NAME
        self.conv_id = None

    def get_body(self, query_str: str, response_lang: str, stream: bool = False):
        """Build the JSON-serialisable request body for a chat query.

        Args:
            query_str: The user's query text.
            response_lang: Value placed in ``generation.response_language``.
            stream: Value placed in ``stream_response``.

        Returns:
            A dict with ``query``, ``search``, ``generation``, ``chat`` and
            ``stream_response`` entries.
        """
        corpora_list = [
            {'corpus_key': corpus_key, 'lexical_interpolation': 0.005}
            for corpus_key in self.corpus_keys
        ]

        return {
            'query': query_str,
            'search':
            {
                'corpora': corpora_list,
                'offset': 0,
                'limit': 50,
                'context_configuration':
                {
                    'sentences_before': 2,
                    'sentences_after': 2,
                    'start_tag': "%START_SNIPPET%",
                    'end_tag': "%END_SNIPPET%",
                },
                'reranker':
                {
                    "type": "chain",
                    "rerankers": [
                        {
                            "type": "customer_reranker",
                            "reranker_name": "Rerank_Multilingual_v1"
                        },
                        {
                            "type": "mmr",
                            "diversity_bias": 0.05
                        }
                    ]
                },
            },
            'generation':
            {
                'generation_preset_name': self.prompt_name,
                'max_used_search_results': 7,
                'response_language': response_lang,
                'citations':
                {
                    'style': 'markdown',
                    'url_pattern': '{doc.url}'
                },
                'enable_factual_consistency_score': True
            },
            'chat':
            {
                'store': True
            },
            'stream_response': stream
        }

    def get_headers(self):
        """Return the HTTP headers used for a non-streaming (JSON) request."""
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "x-api-key": self.api_key,
            "grpc-timeout": "60S"
        }

    def get_stream_headers(self):
        """Return the HTTP headers used for a streaming (event-stream) request."""
        # Identical to the JSON headers except for the accepted media type.
        headers = self.get_headers()
        headers["Accept"] = "text/event-stream"
        return headers

    def _endpoint(self) -> str:
        """Return the URL to POST to, depending on whether a chat already exists."""
        if self.conv_id:
            return f"{self.CHATS_URL}/{self.conv_id}/turns"
        return self.CHATS_URL

    @staticmethod
    def _failure_message(response) -> str:
        """Log a non-200 response and return the user-facing apology text."""
        print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
        if response.status_code == 429:
            return "Sorry, Vectara chat turns exceeds plan limit."
        return "Sorry, something went wrong in my brain. Please try again later."

    def _iter_generation_chunks(self, lines):
        """Yield generation chunks parsed from ``key:value`` stream lines.

        Only ``data`` lines are parsed (as JSON). Events whose ``type`` is
        ``generation_chunk`` are yielded; a ``chat_info`` event updates
        ``conv_id`` from its ``chat_id``. Everything else is ignored.

        Args:
            lines: Iterable of raw lines (``bytes`` or ``str``).
        """
        for raw in lines:
            text = raw.decode('utf-8') if isinstance(raw, bytes) else raw
            if not text:  # filter out keep-alive new lines
                continue
            # partition() never raises, unlike unpacking split(':', 1) on a
            # line that happens to contain no colon.
            key, sep, value = text.partition(':')
            if not sep or key != 'data':
                continue
            event = json.loads(value)
            event_type = event.get('type')
            if event_type == 'generation_chunk':
                yield event['generation_chunk']
            elif event_type == 'chat_info':
                self.conv_id = event['chat_id']

    def submit_query(self, query_str: str, language: str):
        """Send a query and return the generated answer as a single string.

        Args:
            query_str: The user's query text.
            language: Response language passed through to ``get_body``.

        Returns:
            The ``answer`` field of the JSON reply, or an apology string when
            the HTTP status is not 200.
        """
        body = self.get_body(query_str, language, stream=False)
        response = requests.post(
            self._endpoint(), data=json.dumps(body), verify=True, headers=self.get_headers()
        )

        if response.status_code != 200:
            return self._failure_message(response)

        res = response.json()

        # Remember the conversation so later queries become turns of it.
        if self.conv_id is None:
            self.conv_id = res['chat_id']

        return res['answer']

    def submit_query_streaming(self, query_str: str, language: str):
        """Send a query and yield the answer incrementally.

        This is a generator. On success it yields each generation chunk as it
        arrives. On a non-200 status it yields the apology string as its only
        item, so that callers iterating the stream actually see the message.

        Args:
            query_str: The user's query text.
            language: Response language passed through to ``get_body``.

        Yields:
            Generation chunks, or a single apology string on failure.

        Returns:
            (as the generator's ``StopIteration`` value) the concatenation of
            all chunks, or the apology string on failure.
        """
        body = self.get_body(query_str, language, stream=True)

        # The context manager releases the connection even when the consumer
        # stops iterating before the stream is exhausted.
        with requests.post(
            self._endpoint(),
            data=json.dumps(body),
            verify=True,
            headers=self.get_stream_headers(),
            stream=True,
        ) as response:
            if response.status_code != 200:
                message = self._failure_message(response)
                # A bare ``return message`` inside a generator is invisible to
                # ``for`` loops; yield it so the caller can display it.
                yield message
                return message

            chunks = []
            for chunk in self._iter_generation_chunks(response.iter_lines()):
                chunks.append(chunk)
                yield chunk

        return ''.join(chunks)