zolicsaki committed on
Commit
fb0ce27
·
verified ·
1 Parent(s): 7dee66a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +405 -0
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import html
import json
import logging
import os
import re
import sys
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Callable, Dict, Generator, List, Optional

import requests
import streamlit as st
import yaml

from consts import (
    AUTO_SEARCH_KEYWORD,
    DEFAULT_SEARCH_ENGINE_TIMEOUT,
    GOOGLE_SEARCH_ENDPOINT,
    RAG_TEMPLATE,
    RELATED_QUESTIONS_TEMPLATE_NO_SEARCH,
    RELATED_QUESTIONS_TEMPLATE_SEARCH,
    SEARCH_TOOL_INSTRUCTION,
)
15
+
16
# Directory of this file, its parent ("kit") and grandparent ("repo").
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))

# Make modules that live in the parent and grandparent directories importable.
sys.path.append(kit_dir)
sys.path.append(repo_dir)
22
+
23
+
24
+ from visual_env_utils import are_credentials_set, env_input_fields, initialize_env_variables, save_credentials
25
+
26
logging.basicConfig(level=logging.INFO)

# Values sent as the `key` / `cx` parameters of every Google search request;
# both are read from Streamlit's secrets store at import time.
GOOGLE_API_KEY = st.secrets["google_api_key"]
GOOGLE_CX = st.secrets["google_cx"]

# YAML settings file that sits next to this module.
CONFIG_PATH = os.path.join(current_dir, "config.yaml")

# NOTE(review): not referenced anywhere else in the visible part of the file.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
]
36
+
37
def load_config() -> dict:
    """
    Load the application settings from ``CONFIG_PATH``.

    Returns:
        dict: The parsed YAML mapping. An empty file (which ``yaml.safe_load``
        parses to ``None``) yields ``{}`` so that the module-level
        ``config.get(...)`` lookups keep working.
    """
    # Explicit encoding: do not depend on the platform's default locale.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
        return yaml.safe_load(yaml_file) or {}
40
+
41
+
42
# Settings are read once at import time and shared by the functions below.
config = load_config()
prod_mode = config.get('prod_mode', False)
additional_env_vars = config.get('additional_env_vars', None)
45
+
46
@contextmanager
def st_capture(output_func: Callable[[str], None]) -> Generator[None, None, None]:
    """
    Context manager that captures stdout and forwards it to an output callback.

    While the context is active, ``sys.stdout`` is redirected to an in-memory
    buffer. After every write, ``output_func`` is called with the *entire*
    text captured so far (not just the newly written chunk).

    Args:
        output_func (Callable[[str], None]): Function that receives the
            accumulated terminal output, e.g. the write method of a
            Streamlit element.

    Yields:
        None: Control is handed to the ``with`` body while stdout is captured.
    """
    with StringIO() as stdout, redirect_stdout(stdout):
        old_write = stdout.write

        def new_write(string: str) -> int:
            # Write to the buffer first so getvalue() already includes `string`.
            ret = old_write(string)
            output_func(stdout.getvalue())
            return ret

        # Patch the instance so every print()/write() triggers the callback.
        stdout.write = new_write  # type: ignore
        yield
65
+
66
async def run_samba_api_inference(
    query,
    system_prompt=None,
    ignore_context=False,
    max_tokens_to_generate=None,
    num_seconds_to_sleep=2,
    max_retries=5,
):
    """
    Send a chat-completion request to the configured endpoint and return the reply.

    Args:
        query (str): The user message appended as the last chat turn.
        system_prompt (str, optional): System message placed first when given.
        ignore_context (bool): When True, earlier turns stored in
            ``st.session_state.chat_history`` are not sent.
        max_tokens_to_generate (int, optional): Sent as ``max_tokens`` when given.
        num_seconds_to_sleep (int | float): Pause between retry attempts.
        max_retries (int): Maximum number of retries after a retryable HTTP
            status (401, 429, 503, 504) before giving up.

    Returns:
        str: ``choices[0].message.content`` of the JSON response, or ``""``
        when the request fails with an HTTP error that is not retried (or the
        retries are exhausted).
    """
    # First construct messages
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    if not ignore_context:
        history = st.session_state.chat_history
        # History is stored in triples: question, answer, search-status text.
        for ques, ans in zip(history[::3], history[1::3]):
            # Answers are appended as (response, citation_links) tuples by
            # handle_userinput; only the response text belongs in the payload.
            answer_text = ans[0] if isinstance(ans, (tuple, list)) else ans
            messages.append({"role": "user", "content": ques})
            messages.append({"role": "assistant", "content": answer_text})
    messages.append({"role": "user", "content": query})

    # Create payloads
    payload = {
        "messages": messages,
        "model": config.get("model"),
    }
    if max_tokens_to_generate is not None:
        payload["max_tokens"] = max_tokens_to_generate
    headers = {
        "Authorization": f"Basic {st.session_state.SAMBANOVA_API_KEY}",
        "Content-Type": "application/json",
    }

    loop = asyncio.get_running_loop()
    # Retry in a bounded loop with the SAME payload; a recursive retry with only
    # `query` would silently drop the system prompt and token limit.
    for attempt in range(max_retries + 1):
        # requests is blocking, so run it in the default executor.
        post_response = await loop.run_in_executor(
            None,
            lambda: requests.post(config.get("url"), json=payload, headers=headers, stream=True),
        )
        try:
            post_response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            retryable = post_response.status_code in {401, 503, 504, 429}
            if retryable and attempt < max_retries:
                print(f"Attempt failed due to rate limit or gate timeout. Status code: {post_response.status_code}. Trying again in {num_seconds_to_sleep} seconds...")
                await asyncio.sleep(num_seconds_to_sleep)
                continue
            print(f"Request failed with status code: {post_response.status_code}. Error: {e}")
            return ""

        response_data = json.loads(post_response.text)
        return response_data["choices"][0]["message"]["content"]

    return ""
108
+
109
def extract_query(text: str) -> Optional[str]:
    """
    Extract the search query from a ``query="..."`` fragment.

    Args:
        text (str): Text expected to contain ``query="<something>"``.

    Returns:
        Optional[str]: The text between the first pair of double quotes that
        follows ``query=``, or None when no such fragment exists.
    """
    # Non-greedy so the capture stops at the first closing quote.
    match = re.search(r'query="(.*?)"', text)
    return match.group(1) if match else None
117
+
118
def extract_text_between_brackets(text: str) -> List[str]:
    """
    Return every piece of text enclosed in square brackets.

    Args:
        text (str): Text to scan.

    Returns:
        List[str]: The contents of each ``[...]`` pair in order of appearance,
        without the brackets. Matching is non-greedy and does not span
        newlines; the list is empty when nothing matches.
    """
    return re.findall(r'\[(.*?)\]', text)
122
+
123
def search_with_google(query: str):
    """
    Search with google and return the contexts.

    Args:
        query (str): The search query sent as the ``q`` parameter.

    Returns:
        list: Up to five result items from the response's ``"items"`` array;
        an empty list when the response carries no ``"items"`` key.

    Raises:
        Exception: With ``(status_code, "Search engine error.")`` when the
            HTTP response is not OK.
    """
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query,
        "num": 5,
    }
    response = requests.get(
        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
    )

    if not response.ok:
        raise Exception(response.status_code, "Search engine error.")
    json_content = response.json()

    # A response without results has no "items" key; treat that as
    # "nothing found" instead of crashing with a KeyError.
    contexts = json_content.get("items", [])[:5]

    return contexts
144
+
145
async def get_related_questions(query, contexts=None):
    """
    Ask the model for follow-up questions related to ``query``.

    Args:
        query (str): The user's question.
        contexts (list[dict], optional): Search results; their ``"snippet"``
            values are joined into the prompt when provided.

    Returns:
        list: The parsed JSON list returned by the model, or ``[]`` when the
        reply cannot be parsed as a JSON list.
    """
    if contexts:
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
            context="\n\n".join(c.get("snippet", "") for c in contexts)
        )
    else:
        # When no search is performed, use the generic (context-free) prompt.
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_NO_SEARCH

    related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)

    try:
        parsed = json.loads(related_questions_raw)
    except (TypeError, ValueError):
        # The reply may wrap the JSON array in extra prose; retry on the
        # outermost [...] span (DOTALL so a multi-line array still matches).
        bracketed = re.search(r'\[.*\]', related_questions_raw or "", re.DOTALL)
        if bracketed is None:
            return []
        try:
            parsed = json.loads(bracketed.group(0))
        except ValueError:
            return []

    # Callers iterate the result and render one button per entry.
    return parsed if isinstance(parsed, list) else []
164
+
165
def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
    """
    Replace citation markers in the response with superscript numbers.

    Each ``[citation:N]`` (or ``[[citation:N]]``, the form used to label the
    context in the RAG prompt) becomes ``<sup>[N]</sup>``. The cited number N
    is preserved so the marker lines up with the entry of the same number in
    the list produced by ``generate_citation_links``.

    Args:
        response (str): The original response with citations.
        search_result_contexts (List[Dict]): The search results with context
            information. Currently unused; kept for interface compatibility.

    Returns:
        str: The processed response with numbered superscripts for citations.
    """
    def _to_superscript(match):
        # Exactly one of the two alternatives captured the number.
        number = match.group(1) or match.group(2)
        return f'<sup>[{number}]</sup>'

    # Try the double-bracket form first so no stray brackets are left behind.
    return re.sub(
        r'\[\[citation:(\d+)\]\]|\[citation:(\d+)\]',
        _to_superscript,
        response,
    )
182
+
183
def generate_citation_links(search_result_contexts: List[Dict]) -> str:
    """
    Generate HTML for citation links.

    Args:
        search_result_contexts (List[Dict]): The search results with context
            information; ``"title"`` and ``"link"`` are read from each entry,
            defaulting to ``'No title'`` and ``'#'``.

    Returns:
        str: HTML string with one numbered ``<p>`` link per result, numbered
        from 1 in input order; empty string for an empty list.
    """
    citation_links = []
    for i, context in enumerate(search_result_contexts, 1):
        # Titles and links come from web search results and the HTML is later
        # rendered with unsafe_allow_html=True, so escape both.
        title = html.escape(str(context.get('title', 'No title')))
        link = html.escape(str(context.get('link', '#')), quote=True)
        citation_links.append(f'<p>[{i}] <a href="{link}" target="_blank">{title}</a></p>')

    return ''.join(citation_links)
200
+
201
+
202
async def run_auto_search_pipe(query):
    """
    Answer ``query``, running a web search first when the model asks for one.

    A direct answer and context-free related questions are started
    speculatively while a short "is a search needed?" call runs. If that call's
    output contains ``AUTO_SEARCH_KEYWORD``, the speculative work is discarded
    and a search-augmented answer is generated instead.

    Args:
        query (str): The user's question.

    Returns:
        tuple: ``(response, citation_links_html, related_questions)``;
        ``citation_links_html`` is ``""`` when no search was performed.
        ``st.session_state.search_performed`` is set accordingly.
    """
    # Start the no-search answer right away so it overlaps with the check below.
    full_context_answer = asyncio.create_task(run_samba_api_inference(query))
    related_questions_no_search = asyncio.create_task(get_related_questions(query))

    # First call the model with the special system prompt for auto search
    with st.spinner('Checking if web search is needed...'):
        auto_search_result = await run_samba_api_inference(
            query, SEARCH_TOOL_INSTRUCTION, True, max_tokens_to_generate=100
        )

    # The model answered directly: use the speculative full-context answer.
    if AUTO_SEARCH_KEYWORD not in auto_search_result:
        st.session_state.search_performed = False
        result = await full_context_answer
        related_questions = await related_questions_no_search
        return result, "", related_questions

    # The model requested a search: the speculative results are not needed, so
    # cancel them instead of leaving pending tasks behind.
    full_context_answer.cancel()
    related_questions_no_search.cancel()

    st.session_state.search_performed = True

    # extract_query returns None when no query="..." fragment is present;
    # fall back to the user's own question rather than searching for None.
    search_query = extract_query(auto_search_result) or query

    # search
    with st.spinner('Searching the internet...'):
        search_result_contexts = search_with_google(search_query)

    # RAG response
    with st.spinner('Generating response based on web search...'):
        rag_system_prompt = RAG_TEMPLATE.format(
            context="\n\n".join(
                f"[[citation:{i+1}]] {c.get('snippet', '')}"
                for i, c in enumerate(search_result_contexts)
            )
        )

        # Run the answer and the related-question generation concurrently.
        model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
        related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))

        # Process citations and generate links
        citation_links = generate_citation_links(search_result_contexts)

        model_response_complete = await model_response
        processed_response = process_citations(model_response_complete, search_result_contexts)
        related_questions_complete = await related_questions

    return processed_response, citation_links, related_questions_complete
243
+
244
+
245
def handle_userinput(user_question: Optional[str]) -> None:
    """
    Handle user input and generate a response, also update chat UI in streamlit app

    When ``user_question`` is non-empty, the auto-search pipeline is run and
    three entries are appended to ``st.session_state.chat_history``: the
    question, a ``(response, citation_links)`` tuple and a status text. The
    whole history and the related-question buttons are then rendered.

    Args:
        user_question (str): The user's question or input.
    """
    # NOTE(review): block nesting was reconstructed from a flattened source;
    # rendering is assumed to run on every call, not only for new questions.
    if user_question:
        # Clear any existing related question buttons
        if 'related_questions' in st.session_state:
            del st.session_state.related_questions

        async def run_search():
            return await run_auto_search_pipe(user_question)

        response, citation_links, related_questions = asyncio.run(run_search())
        if st.session_state.search_performed:
            search_or_not_text = "🔍 Web search was performed for this query."
        else:
            search_or_not_text = "📚 This response was generated from the model's knowledge."

        st.session_state.chat_history.append(user_question)
        st.session_state.chat_history.append((response, citation_links))
        st.session_state.chat_history.append(search_or_not_text)

        # Store related questions in session state
        st.session_state.related_questions = related_questions

    history = st.session_state.chat_history
    for ques, ans, search_or_not_text in zip(history[::3], history[1::3], history[2::3]):
        with st.chat_message('user'):
            st.write(f'{ques}')

        with st.chat_message(
            'ai',
            avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
        ):
            st.markdown(f'{ans[0]}', unsafe_allow_html=True)
            if ans[1]:
                st.markdown("### Sources", unsafe_allow_html=True)
                st.markdown(ans[1], unsafe_allow_html=True)
            st.info(search_or_not_text)

    # The key can be missing here (the "Reset conversation" button deletes it
    # after main() initialised it), so do not assume it exists.
    if 'related_questions' in st.session_state:
        related = st.session_state.related_questions
    else:
        related = []

    if len(related) > 0:
        st.markdown("### Related Questions")
        for idx, question in enumerate(related):
            # Explicit keys keep buttons distinct even when two labels repeat
            # or match another button on the page.
            if st.button(question, key=f'related_question_{idx}'):
                setChatInputValue(question)
294
+
295
def setChatInputValue(chat_input_value: str) -> None:
    """
    Pre-fill the Streamlit chat input box with ``chat_input_value``.

    Renders an HTML component whose script looks up the chat textarea in the
    parent document, sets its value through the native setter and dispatches an
    ``input`` event.

    Args:
        chat_input_value (str): Text to place into the chat input.
    """
    # Encode the text as a JavaScript string literal: json.dumps escapes
    # quotes, backslashes and newlines. "</" is additionally escaped so the
    # text can never terminate the surrounding <script> element.
    js_literal = json.dumps(chat_input_value).replace('</', '<\\/')
    js = f"""
    <script>
        function insertText(dummy_var_to_force_repeat_execution) {{
            var chatInput = parent.document.querySelector('textarea[data-testid="stChatInputTextArea"]');
            var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
            nativeInputValueSetter.call(chatInput, {js_literal});
            var event = new Event('input', {{ bubbles: true}});
            chatInput.dispatchEvent(event);
        }}
        insertText(3);
    </script>
    """
    st.components.v1.html(js)
309
+
310
def main() -> None:
    """
    Build the Streamlit page: credentials sidebar, example queries, chat input.

    Initialises the session-state keys used by the app, renders the sidebar
    (credential entry/clearing, example query buttons, conversation reset),
    adds the footer and finally passes the chat input to ``handle_userinput``.
    """
    # NOTE(review): block nesting below was reconstructed from a flattened
    # source (sidebar contents, footer placement) -- confirm against the app.
    st.set_page_config(
        page_title='Auto Web Search Demo',
        page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
    )

    initialize_env_variables(prod_mode, additional_env_vars)

    if 'input_disabled' not in st.session_state:
        st.session_state.input_disabled = True
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'search_performed' not in st.session_state:
        st.session_state.search_performed = False
    if 'related_questions' not in st.session_state:
        st.session_state.related_questions = []

    st.title(':orange[SambaNova Cloud] Auto Web Search')

    with st.sidebar:
        st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')

        if not are_credentials_set(additional_env_vars):
            api_key, additional_vars = env_input_fields(additional_env_vars)
            if st.button('Save Credentials'):
                message = save_credentials(api_key, additional_vars, prod_mode)
                st.session_state.input_disabled = False
                st.success(message)
                st.rerun()

        else:
            st.success('Credentials are set')
            if st.button('Clear Credentials'):
                save_credentials('', {var: '' for var in (additional_env_vars or [])}, prod_mode)
                st.session_state.input_disabled = True
                st.rerun()

        if are_credentials_set(additional_env_vars):
            with st.expander('**Example Queries With Search**', expanded=True):
                for example in (
                    'What is the population of Virginia?',
                    'SNP 500 moves today',
                    'What is the weather in Palo Alto?',
                ):
                    if st.button(example):
                        setChatInputValue(example)

            with st.expander('**Example Queries No Search**', expanded=True):
                for example in (
                    'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.',
                    'Write a python program to find the longest root to leaf path in a tree, and some test cases for it.',
                ):
                    if st.button(example):
                        setChatInputValue(example)

            st.markdown('**Reset chat**')
            st.markdown('**Note:** Resetting the chat will clear all interactions history')
            if st.button('Reset conversation'):
                st.session_state.chat_history = []
                st.session_state.sources_history = []
                # Reset to an empty list instead of deleting the key:
                # handle_userinput reads it later in this same script run.
                st.session_state.related_questions = []
                st.toast('Interactions reset. The next response will clear the history on the screen')

    # Add a footer with the GitHub citation
    footer_html = """
    <style>
    .footer {
        position: fixed;
        right: 10px;
        bottom: 10px;
        width: auto;
        background-color: transparent;
        color: grey;
        text-align: right;
        padding: 10px;
        font-size: 16px;
    }
    </style>
    <div class="footer">
        Inspired by: <a href="https://github.com/leptonai/search_with_lepton" target="_blank">search_with_lepton</a>
    </div>
    """
    st.markdown(footer_html, unsafe_allow_html=True)

    user_question = st.chat_input('Ask something', disabled=st.session_state.input_disabled, key='TheChatInput')
    handle_userinput(user_question)
401
+
402
+
403
+
404
# Script entry point: build the page when the file is executed directly.
if __name__ == '__main__':
    main()