ofermend committed on
Commit
7aa515b
·
verified ·
1 Parent(s): 1ab2d63

Update query.py

Browse files
Files changed (1) hide show
  1. query.py +124 -46
query.py CHANGED
@@ -8,31 +8,62 @@ def extract_between_tags(text, start_tag, end_tag):
8
  end_index = text.find(end_tag, start_index)
9
  return text[start_index+len(start_tag):end_index-len(end_tag)]
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class VectaraQuery():
12
  def __init__(self, api_key: str, customer_id: str, corpus_ids: list[str], prompt_name: str = None):
13
  self.customer_id = customer_id
14
  self.corpus_ids = corpus_ids
15
  self.api_key = api_key
16
- self.prompt_name = prompt_name if prompt_name else "vectara-summary-ext-v1.2.0"
17
  self.conv_id = None
18
 
19
- def submit_query(self, query_str: str):
20
  corpora_key_list = [{
21
  'customer_id': self.customer_id, 'corpus_id': corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
22
  } for corpus_id in self.corpus_ids
23
  ]
24
 
25
- endpoint = f"https://api.vectara.io/v1/query"
26
- start_tag = "%START_SNIPPET%"
27
- end_tag = "%END_SNIPPET%"
28
- headers = {
29
- "Content-Type": "application/json",
30
- "Accept": "application/json",
31
- "customer-id": self.customer_id,
32
- "x-api-key": self.api_key,
33
- "grpc-timeout": "60S"
34
- }
35
- body = {
36
  'query': [
37
  {
38
  'query': query_str,
@@ -42,8 +73,8 @@ class VectaraQuery():
42
  'context_config': {
43
  'sentences_before': 2,
44
  'sentences_after': 2,
45
- 'start_tag': start_tag,
46
- 'end_tag': end_tag,
47
  },
48
  'rerankingConfig':
49
  {
@@ -61,14 +92,27 @@ class VectaraQuery():
61
  'store': True,
62
  'conversationId': self.conv_id
63
  },
64
- # 'debug': True,
65
  }
66
  ]
67
  }
68
  ]
69
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=headers)
72
  if response.status_code != 200:
73
  print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
74
  return "Sorry, something went wrong in my brain. Please try again later."
@@ -79,9 +123,9 @@ class VectaraQuery():
79
  summary = res['responseSet'][0]['summary'][0]['text']
80
  responses = res['responseSet'][0]['response'][:top_k]
81
  docs = res['responseSet'][0]['document']
82
- chat = res['responseSet'][0]['summary'][0]['chat']
83
 
84
- if chat['status'] != None:
85
  st_code = chat['status']
86
  print(f"Chat query failed with code {st_code}")
87
  if st_code == 'RESOURCE_EXHAUSTED':
@@ -89,34 +133,68 @@ class VectaraQuery():
89
  return 'Sorry, Vectara chat turns exceeds plan limit.'
90
  return 'Sorry, something went wrong in my brain. Please try again later.'
91
 
92
- self.conv_id = res['responseSet'][0]['summary'][0]['chat']['conversationId']
 
 
 
 
 
 
 
93
 
94
- pattern = r'\[\d{1,2}\]'
95
- matches = [match.span() for match in re.finditer(pattern, summary)]
 
 
96
 
97
- # figure out unique list of references
98
- refs = []
99
- for match in matches:
100
- start, end = match
101
- response_num = int(summary[start+1:end-1])
102
- doc_num = responses[response_num-1]['documentIndex']
103
- metadata = {item['name']: item['value'] for item in docs[doc_num]['metadata']}
104
- text = extract_between_tags(responses[response_num-1]['text'], start_tag, end_tag)
105
- if 'url' in metadata.keys():
106
- url = f"{metadata['url']}#:~:text={quote(text)}"
107
- if url not in refs:
108
- refs.append(url)
109
 
110
- # replace references with markdown links
111
- refs_dict = {url:(inx+1) for inx,url in enumerate(refs)}
112
- for match in reversed(matches):
113
- start, end = match
114
- response_num = int(summary[start+1:end-1])
115
- doc_num = responses[response_num-1]['documentIndex']
116
- metadata = {item['name']: item['value'] for item in docs[doc_num]['metadata']}
117
- text = extract_between_tags(responses[response_num-1]['text'], start_tag, end_tag)
118
- url = f"{metadata['url']}#:~:text={quote(text)}"
119
- citation_inx = refs_dict[url]
120
- summary = summary[:start] + f'[\[{citation_inx}\]]({url})' + summary[end:]
 
 
 
 
 
 
121
 
122
- return summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  end_index = text.find(end_tag, start_index)
9
  return text[start_index+len(start_tag):end_index-len(end_tag)]
10
 
11
class CitationNormalizer():
    """Rewrite numeric citations such as ``[3]`` in a summary as markdown links.

    Each citation number is treated as a 1-based index into ``responses``.
    The response's ``documentIndex`` selects an entry of ``docs``, whose
    ``metadata`` list of ``{'name': ..., 'value': ...}`` items is searched for
    a ``url``. Citations that resolve to a URL are renumbered in order of
    first appearance and replaced by ``[\\[n\\]](url#:~:text=<quoted snippet>)``.
    Citations that cannot be resolved (number out of range, or no ``url``
    metadata) are left untouched.
    """

    def __init__(self, responses=None, docs=None):
        """Store the search responses and documents used to resolve citations.

        Both arguments are optional so the normalizer can also be built empty
        and given the data in the ``normalize_citations`` call.
        """
        self.docs = docs if docs is not None else []
        self.responses = responses if responses is not None else []
        # Unique citation URLs in order of first appearance; kept on the
        # instance, so numbering continues across calls on the same object.
        self.refs = []

    def _citation_url(self, response_num):
        """Return the link target for a 1-based citation number, or None."""
        # Reject 0 and anything past the end: a plain ``responses[n - 1]``
        # would raise IndexError or (for 0) wrap around to the last response.
        if not 1 <= response_num <= len(self.responses):
            return None
        response = self.responses[response_num - 1]
        doc_num = response['documentIndex']
        metadata = {item['name']: item['value'] for item in self.docs[doc_num]['metadata']}
        if 'url' not in metadata:
            return None
        text = extract_between_tags(response['text'], "%START_SNIPPET%", "%END_SNIPPET%")
        return f"{metadata['url']}#:~:text={quote(text)}"

    def normalize_citations(self, summary, responses=None, docs=None):
        """Return ``summary`` with resolvable ``[n]`` citations turned into links.

        Args:
            summary: text containing one- or two-digit ``[n]`` citations.
            responses: optional replacement for the responses given at
                construction time.
            docs: optional replacement for the docs given at construction time.

        Returns:
            The summary with each resolvable citation replaced by a markdown
            link; unresolvable citations are kept verbatim.
        """
        if responses is not None:
            self.responses = responses
        if docs is not None:
            self.docs = docs

        # find all references in the summary: (start, end, cited number)
        matches = [
            (m.start(), m.end(), int(m.group(1)))
            for m in re.finditer(r'\[(\d{1,2})\]', summary)
        ]

        # figure out unique list of references, resolving each number once
        urls = {}
        for _, _, response_num in matches:
            if response_num not in urls:
                urls[response_num] = self._citation_url(response_num)
            url = urls[response_num]
            if url is not None and url not in self.refs:
                self.refs.append(url)

        # replace references with markdown links; go right-to-left so the
        # spans recorded above stay valid while the string is being edited
        refs_dict = {url: (inx + 1) for inx, url in enumerate(self.refs)}
        for start, end, response_num in reversed(matches):
            url = urls[response_num]
            if url is None:
                continue
            citation_inx = refs_dict[url]
            summary = summary[:start] + f'[\\[{citation_inx}\\]]({url})' + summary[end:]

        return summary
51
+
52
  class VectaraQuery():
53
  def __init__(self, api_key: str, customer_id: str, corpus_ids: list[str], prompt_name: str = None):
54
  self.customer_id = customer_id
55
  self.corpus_ids = corpus_ids
56
  self.api_key = api_key
57
+ self.prompt_name = prompt_name if prompt_name else "vectara-experimental-summary-ext-2023-12-11-sml"
58
  self.conv_id = None
59
 
60
+ def get_body(self, query_str: str):
61
  corpora_key_list = [{
62
  'customer_id': self.customer_id, 'corpus_id': corpus_id, 'lexical_interpolation_config': {'lambda': 0.025}
63
  } for corpus_id in self.corpus_ids
64
  ]
65
 
66
+ return {
 
 
 
 
 
 
 
 
 
 
67
  'query': [
68
  {
69
  'query': query_str,
 
73
  'context_config': {
74
  'sentences_before': 2,
75
  'sentences_after': 2,
76
+ 'start_tag': "%START_SNIPPET%",
77
+ 'end_tag': "%END_SNIPPET%",
78
  },
79
  'rerankingConfig':
80
  {
 
92
  'store': True,
93
  'conversationId': self.conv_id
94
  },
 
95
  }
96
  ]
97
  }
98
  ]
99
  }
100
+
101
+ def get_headers(self):
102
+ return {
103
+ "Content-Type": "application/json",
104
+ "Accept": "application/json",
105
+ "customer-id": self.customer_id,
106
+ "x-api-key": self.api_key,
107
+ "grpc-timeout": "60S"
108
+ }
109
+
110
+ def submit_query(self, query_str: str):
111
+
112
+ endpoint = f"https://api.vectara.io/v1/query"
113
+ body = self.get_body(query_str)
114
 
115
+ response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers())
116
  if response.status_code != 200:
117
  print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
118
  return "Sorry, something went wrong in my brain. Please try again later."
 
123
  summary = res['responseSet'][0]['summary'][0]['text']
124
  responses = res['responseSet'][0]['response'][:top_k]
125
  docs = res['responseSet'][0]['document']
126
+ chat = res['responseSet'][0]['summary'][0].get('chat', None)
127
 
128
+ if chat and chat['status'] is not None:
129
  st_code = chat['status']
130
  print(f"Chat query failed with code {st_code}")
131
  if st_code == 'RESOURCE_EXHAUSTED':
 
133
  return 'Sorry, Vectara chat turns exceeds plan limit.'
134
  return 'Sorry, something went wrong in my brain. Please try again later.'
135
 
136
+ self.conv_id = chat['conversationId'] if chat else None
137
+ summary = CitationNormalizer().normalize_citations(summary, responses, docs)
138
+ return summary
139
+
140
+ def submit_query_streaming(self, query_str: str):
141
+
142
+ endpoint = f"https://api.vectara.io/v1/stream-query"
143
+ body = self.get_body(query_str)
144
 
145
+ response = requests.post(endpoint, data=json.dumps(body), verify=True, headers=self.get_headers(), stream=True)
146
+ if response.status_code != 200:
147
+ print(f"Query failed with code {response.status_code}, reason {response.reason}, text {response.text}")
148
+ return "Sorry, something went wrong in my brain. Please try again later."
149
 
150
+ chunks = []
151
+ accumulated_text = "" # Initialize text accumulation
152
+ pattern_max_length = 50 # Example heuristic
153
+ for line in response.iter_lines():
154
+ if line: # filter out keep-alive new lines
155
+ data = json.loads(line.decode('utf-8'))
156
+ res = data['result']
 
 
 
 
 
157
 
158
+ # capture responses and docs if we get that first
159
+ response_set = res['responseSet']
160
+ if response_set is not None:
161
+ # do we have chat conv_id to update?
162
+ summary = response_set.get('summary', [])
163
+ if len(summary)>0:
164
+ chat = summary[0].get('chat', None)
165
+ if chat and chat.get('status', None):
166
+ st_code = chat['status']
167
+ print(f"Chat query failed with code {st_code}")
168
+ if st_code == 'RESOURCE_EXHAUSTED':
169
+ self.conv_id = None
170
+ return 'Sorry, Vectara chat turns exceeds plan limit.'
171
+ return 'Sorry, something went wrong in my brain. Please try again later.'
172
+ conv_id = chat.get('conversationId', None) if chat else None
173
+ if conv_id:
174
+ self.conv_id = conv_id
175
 
176
+ else:
177
+ # grab next chunk and yield it as output
178
+ summary = res.get('summary', None)
179
+ if summary is None:
180
+ continue
181
+ chunk = data['result']['summary']['text']
182
+ accumulated_text += chunk # Append current chunk to accumulation
183
+ if len(accumulated_text) > pattern_max_length:
184
+ accumulated_text = re.sub(r"\[\d+\]", "", accumulated_text)
185
+ accumulated_text = re.sub(r"\s+\.", ".", accumulated_text)
186
+ out_chunk = accumulated_text[:-pattern_max_length]
187
+ chunks.append(out_chunk)
188
+ yield out_chunk
189
+ accumulated_text = accumulated_text[-pattern_max_length:]
190
+
191
+ if summary['done']:
192
+ break
193
+
194
+ # yield the last piece
195
+ if len(accumulated_text) > 0:
196
+ accumulated_text = re.sub(r" \[\d+\]\.", ".", accumulated_text)
197
+ chunks.append(accumulated_text)
198
+ yield accumulated_text
199
+
200
+ return ''.join(chunks)