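"""
Streamlit demo app for automatic web search: a small router model decides whether a
query needs a Google search; if so, results are retrieved and a cited RAG answer is
generated via the SambaNova Cloud API, otherwise the model answers directly from
its own knowledge.
"""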
import asyncio
import json
import logging
import os
import random
import re
import sys
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Callable, Dict, Generator, List, Optional

import requests
import streamlit as st
import yaml

from consts import (
    AUTO_SEARCH_KEYWORD,
    DEFAULT_SEARCH_ENGINE_TIMEOUT,
    GOOGLE_SEARCH_ENDPOINT,
    RAG_TEMPLATE,
    RELATED_QUESTIONS_TEMPLATE_NO_SEARCH,
    RELATED_QUESTIONS_TEMPLATE_SEARCH,
    SEARCH_TOOL_INSTRUCTION,
)
# Make the kit and repo roots importable
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from visual_env_utils import are_credentials_set, env_input_fields, initialize_env_variables, save_credentials

logging.basicConfig(level=logging.INFO)

# Google Programmable Search credentials and backup SambaNova keys come from Streamlit secrets
GOOGLE_API_KEY = st.secrets["google_api_key"]
GOOGLE_CX = st.secrets["google_cx"]
BACKUP_KEYS = [
    st.secrets["backup_key_1"],
    st.secrets["backup_key_2"],
    st.secrets["backup_key_3"],
    st.secrets["backup_key_4"],
    st.secrets["backup_key_5"],
]

CONFIG_PATH = os.path.join(current_dir, "config.yaml")
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
]
def load_config():
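    """Load the kit configuration from config.yaml."""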
    with open(CONFIG_PATH, 'r') as yaml_file:
        return yaml.safe_load(yaml_file)
config = load_config()
prod_mode = config.get('prod_mode', False)
additional_env_vars = config.get('additional_env_vars', None)
@contextmanager
def st_capture(output_func: Callable[[str], None]) -> Generator:
    """
    Context manager that captures stdout and forwards it to a Streamlit output element.

    Args:
        output_func: Callable used to write the captured terminal output.

    Yields:
        None. While active, anything written to stdout is passed to output_func.
    """
    with StringIO() as stdout, redirect_stdout(stdout):
        old_write = stdout.write

        def new_write(string: str) -> int:
            ret = old_write(string)
            output_func(stdout.getvalue())
            return ret

        stdout.write = new_write  # type: ignore
        yield
async def run_samba_api_inference(query, system_prompt=None, ignore_context=False, max_tokens_to_generate=None, num_seconds_to_sleep=1, over_ride_key=None):
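    """
    Call the SambaNova Cloud chat completions endpoint and return the response text.

    Args:
        query: The user's current question.
        system_prompt: Optional system prompt to prepend to the messages.
        ignore_context: If True, do not include prior chat history in the request.
        max_tokens_to_generate: Optional cap on the number of generated tokens.
        num_seconds_to_sleep: Back-off delay before retrying on 429/504 responses.
        over_ride_key: Optional API key to use instead of the session key.

    Returns:
        The assistant message content, or an error string on failure.
    """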
    # First construct the messages list from the running chat history
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})
    if not ignore_context:
        # History is stored in triples: question, (answer, citation links), search notice
        for ques, ans in zip(
            st.session_state.chat_history[::3],
            st.session_state.chat_history[1::3],
        ):
            messages.append({"role": "user", "content": ques})
            # Assistant turns are stored as (response, citation_links) tuples
            messages.append({"role": "assistant", "content": ans[0] if isinstance(ans, tuple) else ans})
    messages.append({"role": "user", "content": query})

    # Create the payload
    payload = {
        "messages": messages,
        "model": config.get("model"),
    }
    if max_tokens_to_generate is not None:
        payload["max_tokens"] = max_tokens_to_generate

    api_key = st.session_state.SAMBANOVA_API_KEY if over_ride_key is None else over_ride_key
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    try:
        post_response = await asyncio.get_event_loop().run_in_executor(
            None, lambda: requests.post(config.get("url"), json=payload, headers=headers, stream=True)
        )
        post_response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        if post_response.status_code in {401, 503}:
            st.info("Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/.")
            return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."
        elif post_response.status_code in {429, 504}:
            # Rate limited or timed out: back off, then retry with a random backup key,
            # keeping the original prompt settings
            await asyncio.sleep(num_seconds_to_sleep)
            return await run_samba_api_inference(
                query, system_prompt, ignore_context, max_tokens_to_generate,
                num_seconds_to_sleep, over_ride_key=random.choice(BACKUP_KEYS),
            )
        else:
            print(f"Request failed with status code: {post_response.status_code}. Error: {e}")
            return "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."

    response_data = json.loads(post_response.text)
    return response_data["choices"][0]["message"]["content"]
def extract_query(text):
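    """
    Extract the search query from a tool-call string containing query="...".

    Example (hypothetical tool-call format): extract_query('search(query="weather in Palo Alto")')
    returns 'weather in Palo Alto'; returns None if no query="..." pattern is found.
    """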
    # Regular expression to capture the query within the quotes
    match = re.search(r'query="(.*?)"', text)
    # If a match is found, return the query, otherwise return None
    if match:
        return match.group(1)
    return None
def extract_text_between_brackets(text):
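    """
    Return all substrings found between square brackets in `text`.

    Example: extract_text_between_brackets('foo ["a", "b"] bar') returns ['"a", "b"'].
    """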
    # Use a regular expression to find all text between square brackets
    matches = re.findall(r'\[(.*?)\]', text)
    return matches
def search_with_google(query: str):
    """
    Search with Google and return the top contexts.

    Each context is a Google Custom Search result dict with keys such as
    'title', 'link', and 'snippet'.
    """
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query,
        "num": 5,
    }
    response = requests.get(
        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
    )
    if not response.ok:
        raise Exception(response.status_code, "Search engine error.")
    json_content = response.json()
    # "items" is absent when the search returns no results
    contexts = json_content.get("items", [])[:5]
    return contexts
async def get_related_questions(query, contexts=None):
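    """
    Ask the model for follow-up questions related to `query`, optionally grounded
    in the provided search result contexts. Returns a list of question strings
    (empty if the model output cannot be parsed as JSON).
    """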
    if contexts:
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
            context="\n\n".join([c["snippet"] for c in contexts])
        )
    else:
        # When no search is performed, use the generic no-search prompt
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_NO_SEARCH
    related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)
    try:
        return json.loads(related_questions_raw)
    except json.JSONDecodeError:
        # The model sometimes wraps the JSON list in prose; recover the first
        # bracketed span and parse it as a list
        extracted = extract_text_between_brackets(related_questions_raw)
        try:
            return json.loads(f"[{extracted[0]}]")
        except (IndexError, json.JSONDecodeError):
            return []
def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
    """
    Process citations in the response and replace them with numbered icons.

    Args:
        response (str): The original response with citations.
        search_result_contexts (List[Dict]): The search results with context information.

    Returns:
        str: The processed response with numbered icons for citations.
    """
    # Keep the original citation numbers so they line up with the numbered
    # source links produced by generate_citation_links
    return re.sub(r'\[citation:(\d+)\]', r'<sup>[\1]</sup>', response)
def generate_citation_links(search_result_contexts: List[Dict]) -> str:
    """
    Generate HTML for citation links.

    Args:
        search_result_contexts (List[Dict]): The search results with context information.

    Returns:
        str: HTML string with numbered citation links.
    """
    citation_links = []
    for i, context in enumerate(search_result_contexts, 1):
        title = context.get('title', 'No title')
        link = context.get('link', '#')
        citation_links.append(f'<p>[{i}] <a href="{link}" target="_blank">{title}</a></p>')
    return ''.join(citation_links)
async def run_auto_search_pipe(query):
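    """
    Run the auto-search pipeline for a query.

    Returns a (response, citation_links_html, related_questions) tuple; the
    citation links string is empty when no web search was performed.
    """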
    # Kick off the no-search answer and related questions in parallel; they are
    # only awaited if the router decides a web search is unnecessary
    full_context_answer = asyncio.create_task(run_samba_api_inference(query))
    related_questions_no_search = asyncio.create_task(get_related_questions(query))

    # First call Llama3.1 8B with a special system prompt to decide whether to search
    with st.spinner('Checking if web search is needed...'):
        auto_search_result = await run_samba_api_inference(
            query, SEARCH_TOOL_INSTRUCTION, ignore_context=True, max_tokens_to_generate=100
        )

    # If Llama3.1 8B returns a search query, run the search pipeline
    if AUTO_SEARCH_KEYWORD in auto_search_result:
        st.session_state.search_performed = True
        # Search
        with st.spinner('Searching the internet...'):
            search_result_contexts = search_with_google(extract_query(auto_search_result))
        # RAG response
        with st.spinner('Generating response based on web search...'):
            rag_system_prompt = RAG_TEMPLATE.format(
                context="\n\n".join(
                    [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(search_result_contexts)]
                )
            )
            model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
            related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))

            # Process citations and generate links
            citation_links = generate_citation_links(search_result_contexts)
            model_response_complete = await model_response
            processed_response = process_citations(model_response_complete, search_result_contexts)
            related_questions_complete = await related_questions
        return processed_response, citation_links, related_questions_complete
    # If Llama3.1 8B answers directly, return the full-context answer from the
    # main model (Llama 405B) to get the best possible response
    else:
        st.session_state.search_performed = False
        result = await full_context_answer
        related_questions = await related_questions_no_search
        return result, "", related_questions
def handle_userinput(user_question: Optional[str]) -> None:
    """
    Handle user input, generate a response, and update the chat UI in the Streamlit app.

    Args:
        user_question (Optional[str]): The user's question or input.
    """
    if user_question:
        # Clear any existing related question buttons
        if 'related_questions' in st.session_state:
            st.session_state.related_questions = []

        async def run_search():
            return await run_auto_search_pipe(user_question)

        response, citation_links, related_questions = asyncio.run(run_search())

        if st.session_state.search_performed:
            search_or_not_text = "🔍 Web search was performed for this query."
        else:
            search_or_not_text = "📚 This response was generated from the model's knowledge."

        # History is stored in triples: question, (answer, citation links), search notice
        st.session_state.chat_history.append(user_question)
        st.session_state.chat_history.append((response, citation_links))
        st.session_state.chat_history.append(search_or_not_text)

        # Store related questions in session state
        st.session_state.related_questions = related_questions

    # Re-render the whole conversation on every Streamlit rerun
    for ques, ans, search_or_not_text in zip(
        st.session_state.chat_history[::3],
        st.session_state.chat_history[1::3],
        st.session_state.chat_history[2::3],
    ):
        with st.chat_message('user'):
            st.write(ques)
        with st.chat_message(
            'ai',
            avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
        ):
            st.markdown(ans[0], unsafe_allow_html=True)
            if ans[1]:
                st.markdown("### Sources", unsafe_allow_html=True)
                st.markdown(ans[1], unsafe_allow_html=True)
            st.info(search_or_not_text)

    if len(st.session_state.related_questions) > 0:
        st.markdown("### Related Questions")
        for question in st.session_state.related_questions:
            if st.button(question):
                setChatInputValue(question)
def setChatInputValue(chat_input_value: str) -> None:
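    """
    Inject JavaScript that programmatically fills Streamlit's chat input with
    `chat_input_value` and fires an input event so React picks up the change.
    """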
    # json.dumps escapes quotes and newlines so the value is safe to embed in the JS source
    escaped_value = json.dumps(chat_input_value)
    js = f"""
    <script>
        function insertText(dummy_var_to_force_repeat_execution) {{
            var chatInput = parent.document.querySelector('textarea[data-testid="stChatInputTextArea"]');
            var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
            nativeInputValueSetter.call(chatInput, {escaped_value});
            var event = new Event('input', {{ bubbles: true }});
            chatInput.dispatchEvent(event);
        }}
        insertText(3);
    </script>
    """
    st.components.v1.html(js)
def main() -> None:
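    """Set up the Streamlit page, sidebar credentials and examples, and the chat loop."""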
    st.set_page_config(
        page_title='Auto Web Search Demo',
        page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
    )

    initialize_env_variables(prod_mode, additional_env_vars)

    if 'input_disabled' not in st.session_state:
        st.session_state.input_disabled = 'SAMBANOVA_API_KEY' not in st.session_state
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'search_performed' not in st.session_state:
        st.session_state.search_performed = False
    if 'related_questions' not in st.session_state:
        st.session_state.related_questions = []

    st.title('Auto Web Search')
    st.subheader('Powered by :orange[SambaNova Cloud] and Llama405B')

    with st.sidebar:
        st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')

        if not are_credentials_set(additional_env_vars):
            api_key, additional_vars = env_input_fields(additional_env_vars)
            if st.button('Save Credentials'):
                message = save_credentials(api_key, additional_vars, prod_mode)
                st.session_state.input_disabled = False
                st.success(message)
                st.rerun()
        else:
            st.success('Credentials are set')
            if st.button('Clear Credentials'):
                save_credentials('', {var: '' for var in (additional_env_vars or [])}, prod_mode)
                st.session_state.input_disabled = True
                st.rerun()

        if are_credentials_set(additional_env_vars):
            with st.expander('**Example Queries With Search**', expanded=True):
                if st.button('What is the population of Virginia?'):
                    setChatInputValue('What is the population of Virginia?')
                if st.button('S&P 500 stock market moves'):
                    setChatInputValue('S&P 500 stock market moves')
                if st.button('What is the weather in Palo Alto?'):
                    setChatInputValue('What is the weather in Palo Alto?')
            with st.expander('**Example Queries No Search**', expanded=True):
                if st.button(
                    'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'
                ):
                    setChatInputValue(
                        'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.'
                    )
                if st.button(
                    'Write a python program to find the longest root to leaf path in a tree, and some test cases for it.'
                ):
                    setChatInputValue(
                        'Write a python program to find the longest root to leaf path in a tree, and some test cases for it.'
                    )

            st.markdown('**Reset chat**')
            st.markdown('**Note:** Resetting the chat will clear all interaction history')
            if st.button('Reset conversation'):
                st.session_state.chat_history = []
                st.session_state.sources_history = []
                if 'related_questions' in st.session_state:
                    st.session_state.related_questions = []
                st.toast('Interactions reset. The next response will clear the history on the screen')

    # Add a footer with the GitHub citation
    footer_html = """
    <style>
        .footer {
            position: fixed;
            right: 10px;
            bottom: 10px;
            width: auto;
            background-color: transparent;
            color: grey;
            text-align: right;
            padding: 10px;
            font-size: 16px;
        }
    </style>
    <div class="footer">
        Inspired by: <a href="https://github.com/leptonai/search_with_lepton" target="_blank">search_with_lepton</a>
    </div>
    """
    st.markdown(footer_html, unsafe_allow_html=True)

    user_question = st.chat_input('Ask something', disabled=st.session_state.input_disabled, key='TheChatInput')
    handle_userinput(user_question)


if __name__ == '__main__':
    main()