import asyncio
import html
import json
import logging
import os
import random
import re
import sys
from contextlib import contextmanager, redirect_stdout
from io import StringIO
from typing import Callable, Dict, Generator, List, Optional

import requests
import streamlit as st
import yaml

from consts import (
    AUTO_SEARCH_KEYWORD,
    DEFAULT_SEARCH_ENGINE_TIMEOUT,
    GOOGLE_SEARCH_ENDPOINT,
    RAG_TEMPLATE,
    RELATED_QUESTIONS_TEMPLATE_NO_SEARCH,
    RELATED_QUESTIONS_TEMPLATE_SEARCH,
    SEARCH_TOOL_INSTRUCTION,
)

current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, '..'))
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))

# The helper module below is resolved through these extra search paths, so
# the import has to come after the path manipulation.
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from visual_env_utils import (  # noqa: E402
    are_credentials_set,
    env_input_fields,
    initialize_env_variables,
    save_credentials,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

GOOGLE_API_KEY = st.secrets["google_api_key"]
GOOGLE_CX = st.secrets["google_cx"]
# Keys used when a request has to be retried after a 429/504 response.
BACKUP_KEYS = [st.secrets[f"backup_key_{i}"] for i in range(1, 6)]

CONFIG_PATH = os.path.join(current_dir, "config.yaml")

USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
]

# Message shown/returned whenever the inference request cannot be completed.
_INVALID_KEY_MESSAGE = (
    "Invalid Key! Please make sure you have a valid SambaCloud key from https://cloud.sambanova.ai/."
)


def load_config():
    """Load and return the parsed YAML configuration stored at CONFIG_PATH."""
    with open(CONFIG_PATH, 'r', encoding='utf-8') as yaml_file:
        return yaml.safe_load(yaml_file)


config = load_config()
prod_mode = config.get('prod_mode', False)
additional_env_vars = config.get('additional_env_vars', None)


@contextmanager
def st_capture(output_func: Callable[[str], None]) -> Generator:
    """
    Context manager that captures stdout and forwards it to an output callback.

    Every write to stdout inside the ``with`` body triggers ``output_func``
    with the full text accumulated so far.

    Args:
        output_func (Callable[[str], None]): function receiving the captured
            terminal output.

    Yields:
        Generator: yields nothing; stdout is redirected for the duration.
    """
    with StringIO() as stdout, redirect_stdout(stdout):
        original_write = stdout.write

        def forwarding_write(string: str) -> int:
            written = original_write(string)
            output_func(stdout.getvalue())
            return written

        stdout.write = forwarding_write  # type: ignore
        yield


async def run_samba_api_inference(
    query,
    system_prompt=None,
    ignore_context=False,
    max_tokens_to_generate=None,
    num_seconds_to_sleep=1,
    over_ride_key=None,
    max_retries=5,
):
    """
    Send a chat-completion request to the endpoint configured under ``url``.

    Args:
        query: the user message to send last.
        system_prompt: optional system message placed first.
        ignore_context: when False, previous question/answer pairs stored in
            ``st.session_state.chat_history`` are replayed before ``query``.
        max_tokens_to_generate: optional value sent as ``max_tokens``.
        num_seconds_to_sleep: pause before retrying after a 429/504 response.
        over_ride_key: API key to use instead of the one in session state.
        max_retries: maximum number of retries after 429/504 responses.

    Returns:
        The content of the first choice's message, or an error message string
        when the request fails with an HTTP error status.
    """
    # First construct messages
    messages = []
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    if not ignore_context:
        history = st.session_state.chat_history
        for ques, ans in zip(history[::3], history[1::3]):
            # handle_userinput stores answers as (response, citation_links)
            # tuples; only the response text belongs in the conversation.
            answer_text = ans[0] if isinstance(ans, (tuple, list)) else ans
            messages.append({"role": "user", "content": ques})
            messages.append({"role": "assistant", "content": answer_text})
    messages.append({"role": "user", "content": query})

    # Create payload (identical for every attempt; only the key changes).
    payload = {"messages": messages, "model": config.get("model")}
    if max_tokens_to_generate is not None:
        payload["max_tokens"] = max_tokens_to_generate

    api_key = st.session_state.SAMBANOVA_API_KEY if over_ride_key is None else over_ride_key
    url = config.get("url")
    loop = asyncio.get_running_loop()

    retries_done = 0
    while True:
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # requests is blocking, so run it in the default executor.
        post_response = await loop.run_in_executor(
            None,
            lambda h=headers: requests.post(url, json=payload, headers=h, stream=True),
        )
        try:
            post_response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            status = post_response.status_code
            if status in {401, 503}:
                st.info(_INVALID_KEY_MESSAGE)
                return _INVALID_KEY_MESSAGE
            if status in {429, 504} and retries_done < max_retries:
                # Retry the *same* request (system prompt, context and token
                # limit included) with a randomly chosen backup key.
                retries_done += 1
                await asyncio.sleep(num_seconds_to_sleep)
                api_key = random.choice(BACKUP_KEYS)
                continue
            logger.error("Request failed with status code: %s. Error: %s", status, e)
            return _INVALID_KEY_MESSAGE

        response_data = json.loads(post_response.text)
        return response_data["choices"][0]["message"]["content"]


def extract_query(text):
    """Return the text captured by ``query="..."`` in ``text``, or None if absent."""
    # Regular expression to capture the query within the quotes
    match = re.search(r'query="(.*?)"', text)
    return match.group(1) if match else None


def extract_text_between_brackets(text):
    """Return every (non-greedy, single-line) substring found between ``[`` and ``]``."""
    return re.findall(r'\[(.*?)\]', text)


def search_with_google(query: str):
    """
    Search with google and return the contexts.

    Args:
        query (str): the search query.

    Returns:
        Up to five result items from the response; an empty list when the
        response carries no ``items`` key.

    Raises:
        Exception: when the search endpoint answers with a non-OK status.
    """
    params = {
        "key": GOOGLE_API_KEY,
        "cx": GOOGLE_CX,
        "q": query,
        "num": 5,
    }
    response = requests.get(
        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
    )
    if not response.ok:
        raise Exception(response.status_code, "Search engine error.")
    # A response without results has no "items" key; treat that as empty.
    return response.json().get("items", [])[:5]


def _parse_related_questions(raw: str) -> List[str]:
    """Best-effort conversion of a model reply into a list of question strings."""
    try:
        parsed = json.loads(raw)
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
    except (ValueError, TypeError):
        pass

    # The reply may wrap the JSON array in extra prose: try each bracketed
    # fragment, re-adding the brackets that the extraction strips.
    for fragment in extract_text_between_brackets(raw):
        try:
            parsed = json.loads(f'[{fragment}]')
        except (ValueError, TypeError):
            continue
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
    return []


async def get_related_questions(query, contexts=None):
    """
    Ask the model for follow-up questions related to ``query``.

    Args:
        query: the user's question.
        contexts: optional search result items; their ``snippet`` fields are
            injected into the search-based prompt.

    Returns:
        A list of related question strings; empty when the reply cannot be parsed.
    """
    if contexts:
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_SEARCH.format(
            context="\n\n".join([c["snippet"] for c in contexts])
        )
    else:
        # When no search is performed, use a generic prompt.
        # NOTE(review): assumes the no-search template needs no formatting — confirm in consts.
        related_question_system_prompt = RELATED_QUESTIONS_TEMPLATE_NO_SEARCH

    related_questions_raw = await run_samba_api_inference(query, related_question_system_prompt)
    return _parse_related_questions(related_questions_raw)


def process_citations(response: str, search_result_contexts: List[Dict]) -> str:
    """
    Process citations in the response and replace them with numbered markers.

    Each ``[citation:N]`` becomes ``[N]`` so the marker keeps pointing at the
    N-th search result, matching the numbering of generate_citation_links.

    Args:
        response (str): The original response with citations.
        search_result_contexts (List[Dict]): The search results with context
            information (currently unused, kept for interface compatibility).

    Returns:
        str: The processed response with numbered markers for citations.
    """
    return re.sub(r'\[citation:(\d+)\]', r'[\1]', response)


def generate_citation_links(search_result_contexts: List[Dict]) -> str:
    """
    Generate HTML for citation links.

    Args:
        search_result_contexts (List[Dict]): The search results with context information.

    Returns:
        str: HTML string with numbered citation links.
    """
    citation_links = []
    for i, context in enumerate(search_result_contexts, 1):
        # Titles and links come from web search results: escape them before
        # embedding in HTML that is rendered with unsafe_allow_html=True.
        title = html.escape(str(context.get('title', 'No title')))
        link = html.escape(str(context.get('link', '#')), quote=True)
        citation_links.append(
            f'<p>[{i}] <a href="{link}" target="_blank" rel="noopener noreferrer">{title}</a></p>'
        )
    return ''.join(citation_links)


async def run_auto_search_pipe(query):
    """
    Answer ``query``, running a web search first when the model asks for one.

    Args:
        query: the user's question.

    Returns:
        A ``(response, citation_links_html, related_questions)`` tuple;
        ``citation_links_html`` is an empty string when no search was performed.
    """
    # Start the no-search answer and its related questions right away so they
    # are already in flight if the model decides no search is needed.
    full_context_answer = asyncio.create_task(run_samba_api_inference(query))
    related_questions_no_search = asyncio.create_task(get_related_questions(query))

    # First call the model with the special system prompt for auto search
    with st.spinner('Checking if web search is needed...'):
        auto_search_result = await run_samba_api_inference(
            query, SEARCH_TOOL_INSTRUCTION, True, max_tokens_to_generate=100
        )

    # If the model returns a search query then run the search pipeline
    if AUTO_SEARCH_KEYWORD in auto_search_result:
        st.session_state.search_performed = True

        # The speculative no-search results are not needed on this branch.
        full_context_answer.cancel()
        related_questions_no_search.cancel()

        # search
        with st.spinner('Searching the internet...'):
            # Fall back to the raw question if no query="..." could be extracted.
            search_query = extract_query(auto_search_result) or query
            search_result_contexts = search_with_google(search_query)

        # RAG response
        with st.spinner('Generating response based on web search...'):
            rag_system_prompt = RAG_TEMPLATE.format(
                context="\n\n".join(
                    [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(search_result_contexts)]
                )
            )
            model_response = asyncio.create_task(run_samba_api_inference(query, rag_system_prompt))
            related_questions = asyncio.create_task(get_related_questions(query, search_result_contexts))

            # Process citations and generate links
            citation_links = generate_citation_links(search_result_contexts)
            model_response_complete = await model_response
            processed_response = process_citations(model_response_complete, search_result_contexts)
            related_questions_complete = await related_questions

        return processed_response, citation_links, related_questions_complete

    # The model answered directly: use the full-context answer started above.
    st.session_state.search_performed = False
    result = await full_context_answer
    related_questions = await related_questions_no_search
    return result, "", related_questions


def handle_userinput(user_question: Optional[str]) -> None:
    """
    Handle user input and generate a response, also update chat UI in streamlit app

    Args:
        user_question (str): The user's question or input.
    """
    # NOTE(review): the conversation and the related-question buttons are only
    # rendered on runs where user_question is truthy — confirm this is intended.
    if not user_question:
        return

    # Clear any existing related question buttons
    if 'related_questions' in st.session_state:
        st.session_state.related_questions = []

    response, citation_links, related_questions = asyncio.run(run_auto_search_pipe(user_question))

    if st.session_state.search_performed:
        search_or_not_text = "🔍 Web search was performed for this query."
    else:
        search_or_not_text = "📚 This response was generated from the model's knowledge."

    # chat_history is a flat list of triples: question, (answer, links), note.
    st.session_state.chat_history.append(user_question)
    st.session_state.chat_history.append((response, citation_links))
    st.session_state.chat_history.append(search_or_not_text)

    # Store related questions in session state
    st.session_state.related_questions = related_questions

    history = st.session_state.chat_history
    for ques, ans, note in zip(history[::3], history[1::3], history[2::3]):
        with st.chat_message('user'):
            st.write(f'{ques}')

        with st.chat_message(
            'ai',
            avatar='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
        ):
            st.markdown(f'{ans[0]}', unsafe_allow_html=True)
            if ans[1]:
                st.markdown("### Sources", unsafe_allow_html=True)
                st.markdown(ans[1], unsafe_allow_html=True)
            st.info(note)

    if st.session_state.related_questions:
        st.markdown("### Related Questions")
        for idx, question in enumerate(st.session_state.related_questions):
            # Explicit keys avoid widget-ID clashes with other buttons that
            # carry the same label (e.g. the sidebar example queries).
            if st.button(question, key=f'related_question_{idx}'):
                setChatInputValue(question)


def setChatInputValue(chat_input_value: str) -> None:
    """
    Render an HTML component intended to set the chat input's value.

    Args:
        chat_input_value (str): the text to place in the chat input.
    """
    # NOTE(review): the HTML/JS payload is blank here and chat_input_value is
    # never used — the markup looks stripped from this copy of the file;
    # restore it from the upstream source.
    js = f""" """
    st.components.v1.html(js)


def main() -> None:
    """Configure the page, initialise session state, build the sidebar and run the chat."""
    st.set_page_config(
        page_title='Auto Web Search Demo',
        page_icon='https://sambanova.ai/hubfs/logotype_sambanova_orange.png',
    )

    initialize_env_variables(prod_mode, additional_env_vars)

    # Input stays disabled until an API key is present in the session.
    if 'input_disabled' not in st.session_state:
        st.session_state.input_disabled = 'SAMBANOVA_API_KEY' not in st.session_state
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if 'search_performed' not in st.session_state:
        st.session_state.search_performed = False
    if 'related_questions' not in st.session_state:
        st.session_state.related_questions = []

    st.title(' Auto Web Search')
    st.subheader('Powered by :orange[SambaNova Cloud] and Llama405B')

    with st.sidebar:
        st.title('Get your :orange[SambaNova Cloud] API key [here](https://cloud.sambanova.ai/apis)')

        if not are_credentials_set(additional_env_vars):
            api_key, additional_vars = env_input_fields(additional_env_vars)
            if st.button('Save Credentials'):
                message = save_credentials(api_key, additional_vars, prod_mode)
                st.session_state.input_disabled = False
                st.success(message)
                st.rerun()
        else:
            st.success('Credentials are set')
            if st.button('Clear Credentials'):
                save_credentials('', {var: '' for var in (additional_env_vars or [])}, prod_mode)
                st.session_state.input_disabled = True
                st.rerun()

        if are_credentials_set(additional_env_vars):
            search_examples = (
                'What is the population of Virginia?',
                'SNP 500 stock market moves',
                'What is the weather in Palo Alto?',
            )
            no_search_examples = (
                'write a short poem following a specific pattern: the first letter of every word should spell out the name of a country.',
                'Write a python program to find the longest root to leaf path in a tree, and some test cases for it.',
            )

            with st.expander('**Example Queries With Search**', expanded=True):
                for example in search_examples:
                    if st.button(example):
                        setChatInputValue(example)

            with st.expander('**Example Queries No Search**', expanded=True):
                for example in no_search_examples:
                    if st.button(example):
                        setChatInputValue(example)

            st.markdown('**Reset chat**')
            st.markdown('**Note:** Resetting the chat will clear all interactions history')
            if st.button('Reset conversation'):
                st.session_state.chat_history = []
                st.session_state.sources_history = []
                if 'related_questions' in st.session_state:
                    st.session_state.related_questions = []
                st.toast('Interactions reset. The next response will clear the history on the screen')

        # Add a footer with the GitHub citation
        # NOTE(review): the footer markup is blank in this copy of the file —
        # restore it from the upstream source.
        footer_html = """ """
        st.markdown(footer_html, unsafe_allow_html=True)

    user_question = st.chat_input('Ask something', disabled=st.session_state.input_disabled, key='TheChatInput')
    handle_userinput(user_question)


if __name__ == '__main__':
    main()