|
import os |
|
import streamlit as st |
|
import numpy as np |
|
import uuid |
|
import datetime |
|
from dotenv import load_dotenv |
|
from langchain_community.tools import DuckDuckGoSearchRun |
|
from langchain_groq import ChatGroq |
|
from langchain.chains import LLMChain |
|
from langchain.prompts import PromptTemplate |
|
|
|
|
|
# Load settings from a .env file (if one exists) before anything reads
# os.environ; main() later looks up GROQ_API_KEY there.
load_dotenv()
|
|
|
|
|
def local_css():
    """Inject the app's custom stylesheet into the Streamlit page.

    The CSS is sent as a raw ``<style>`` element through ``st.markdown``
    with ``unsafe_allow_html=True``. It covers the main area, chat
    messages and avatars, the sidebar, chat-history entries, text inputs,
    buttons, the custom header and the animated typing indicator.
    """
    stylesheet = """
    <style>
    .main { background-color: #f9f9fc; font-family: 'Inter', sans-serif; }
    .chat-container { max-width: 900px; margin: 0 auto; padding: 1rem;
    border-radius: 12px; background-color: white;
    box-shadow: 0 2px 10px rgba(0,0,0,0.05); }
    .stChatMessage { padding: 0.5rem 0; }
    [data-testid="stChatMessageContent"] { border-radius: 18px; padding: 0.8rem 1rem; line-height: 1.5; }
    .stChatMessageAvatar { background-color: #1f75fe !important; }
    [data-testid="stChatMessageAvatar"][data-testid*="assistant"] { background-color: #10a37f !important; }
    [data-testid="stSidebar"] { background-color: #ffffff; border-right: 1px solid #e6e6e6; padding: 1rem; }
    .chat-history-item { padding: 10px 15px; margin: 5px 0; border-radius: 8px;
    cursor: pointer; transition: background-color 0.2s; overflow: hidden;
    text-overflow: ellipsis; white-space: nowrap; }
    .chat-history-item:hover { background-color: #f0f0f5; }
    .chat-history-active { background-color: #e6f0ff; border-left: 3px solid #1f75fe; }
    .stTextInput > div > div > input { border-radius: 20px; padding: 10px 15px;
    border: 1px solid #e0e0e0; background-color: #f9f9fc; }
    .stButton > button { border-radius: 20px; padding: 0.3rem 1rem;
    background-color: #1f75fe; color: white; border: none; transition: all 0.2s; }
    .stButton > button:hover { background-color: #0056b3; transform: translateY(-2px); }
    .custom-header { display: flex; align-items: center; margin-bottom: 1rem; }
    .custom-header h1 { margin: 0; font-size: 1.8rem; color: #333; }
    .typing-indicator { display: flex; padding: 10px 15px;
    background-color: #f0f0f5; border-radius: 18px; width: fit-content; }
    .typing-indicator span { height: 8px; width: 8px; margin: 0 1px;
    background-color: #a0a0a0; border-radius: 50%; display: inline-block;
    animation: typing 1.4s infinite ease-in-out both; }
    .typing-indicator span:nth-child(1){animation-delay:0s;}
    .typing-indicator span:nth-child(2){animation-delay:0.2s;}
    .typing-indicator span:nth-child(3){animation-delay:0.4s;}
    @keyframes typing{0%{transform:scale(1);}50%{transform:scale(1.5);}100%{transform:scale(1);}}
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
|
|
|
|
|
def init_session_state():
    """Seed ``st.session_state`` with any chat keys it does not yet hold.

    Keys that are already present are left untouched. Missing ones get:

    * ``messages`` -- an empty list
    * ``chat_sessions`` -- an empty dict
    * ``current_session_id`` -- a fresh UUID4 string
    * ``session_name`` -- ``"Chat <Mon DD, HH:MM>"`` built from the current
      local time
    """
    state = st.session_state
    # Factories rather than values, so a default is only computed when its
    # key is actually missing.
    default_factories = (
        ('messages', list),
        ('chat_sessions', dict),
        ('current_session_id', lambda: str(uuid.uuid4())),
        (
            'session_name',
            lambda: f"Chat {datetime.datetime.now().strftime('%b %d, %H:%M')}",
        ),
    )
    for key, make_default in default_factories:
        if key not in state:
            setattr(state, key, make_default())
|
|
|
def save_chat_session():
    """Record the active chat in ``st.session_state.chat_sessions``.

    The entry is keyed by ``current_session_id`` and holds the session
    name, the message list (stored by reference, not copied) and an
    ISO-8601 timestamp of the save. Nothing happens when the current
    session id is falsy.
    """
    state = st.session_state
    session_id = state.current_session_id
    if not session_id:
        return

    snapshot = {
        "name": state.session_name,
        "messages": state.messages,
        "timestamp": datetime.datetime.now().isoformat(),
    }
    state.chat_sessions[session_id] = snapshot
|
|
|
def load_chat_session(session_id):
    """Make a previously saved chat the active one.

    Args:
        session_id: Key of the chat in ``st.session_state.chat_sessions``.
            Unknown ids are ignored and the session state is left as is.
    """
    state = st.session_state
    if session_id not in state.chat_sessions:
        return

    stored = state.chat_sessions[session_id]
    state.current_session_id = session_id
    state.messages = stored["messages"]
    state.session_name = stored["name"]
|
|
|
def create_new_chat():
    """Switch the session state to a brand-new, empty chat.

    Assigns a fresh UUID4 session id, an empty message list and a default
    name of the form ``"Chat <Mon DD, HH:MM>"`` based on the current local
    time. Previously saved chats in ``chat_sessions`` are not modified.
    """
    fresh_id = str(uuid.uuid4())
    default_name = f"Chat {datetime.datetime.now().strftime('%b %d, %H:%M')}"

    state = st.session_state
    state.current_session_id = fresh_id
    state.messages = []
    state.session_name = default_name
|
|
|
|
|
def setup_models(groq_api_key):
    """Build the Groq chat model and the two answer chains that use it.

    Args:
        groq_api_key: API key handed to ``ChatGroq``.

    Returns:
        A ``(direct_chain, search_chain, llm)`` tuple. ``direct_chain``
        takes a ``question``; ``search_chain`` takes ``web_results`` and a
        ``question``; ``llm`` is the shared ``ChatGroq`` instance.
    """
    llm = ChatGroq(model="llama-3.3-70b-versatile", groq_api_key=groq_api_key)

    # Prompt for questions answered without any web context.
    direct_template = """
    Answer the question in detailed form.

    Question: {question}
    Answer:
    """
    direct_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate(
            input_variables=["question"],
            template=direct_template,
        ),
    )

    # Prompt for questions answered on top of fetched search results.
    search_template = """
    Use these web search results to give a comprehensive answer:

    Search Results:
    {web_results}

    Question: {question}
    Answer:
    """
    search_chain = LLMChain(
        llm=llm,
        prompt=PromptTemplate(
            input_variables=["web_results", "question"],
            template=search_template,
        ),
    )

    return direct_chain, search_chain, llm
|
|
|
def decide_search(query: str, llm) -> tuple[bool, str | None]:
    """Ask the LLM whether ``query`` needs a web search, and with which keywords.

    The model is instructed to reply either ``SEARCH: <keywords>`` or
    ``NO_SEARCH``. The reply is matched case-insensitively on the
    ``SEARCH:`` prefix.

    Args:
        query: The user's question.
        llm: Language model handed to ``LLMChain`` for the decision.

    Returns:
        ``(True, keywords)`` when the reply is ``SEARCH:`` followed by
        non-empty keywords; ``(False, None)`` for every other reply,
        including a ``SEARCH:`` reply with no keywords.
    """
    decision_prompt = PromptTemplate(
        input_variables=["query"],
        template="""
        You are a decision assistant. If the user's question needs up-to-date
        information from the web, respond with "SEARCH: <best keywords>".
        Otherwise respond with "NO_SEARCH". Do not add anything else.

        Question: {query}
        """
    )
    decision_chain = LLMChain(llm=llm, prompt=decision_prompt)
    response = decision_chain.run({"query": query}).strip()

    # The prompt shows the expected reply inside double quotes, so the model
    # may echo them back. Peel off one matching pair that wraps the whole
    # reply; quotes that are part of the keywords themselves are untouched.
    if len(response) >= 2 and response[0] == response[-1] and response[0] in "\"'`":
        response = response[1:-1].strip()

    prefix = "SEARCH:"
    if response.upper().startswith(prefix):
        keywords = response[len(prefix):].strip()
        # An empty keyword string cannot drive a search; report "no search".
        if keywords:
            return True, keywords
    return False, None
|
|
|
# Search results are fetched to supply up-to-date information, so cached
# entries expire after 10 minutes instead of living for the whole app
# lifetime, and the cache is capped so distinct queries cannot grow it
# without bound.
@st.cache_data(ttl=600, max_entries=256)
def perform_search(keywords: str) -> str:
    """Run a DuckDuckGo search and return the tool's result.

    Results are cached per ``keywords`` value through ``st.cache_data``.

    Args:
        keywords: Search terms passed to ``DuckDuckGoSearchRun().run``.

    Returns:
        Whatever ``DuckDuckGoSearchRun().run`` returns for ``keywords``.
    """
    return DuckDuckGoSearchRun().run(keywords)
|
|
|
|
|
def _render_sidebar() -> str:
    """Draw the sidebar (API key, chat controls, history) and return the Groq API key.

    Stops the script run via ``st.stop()`` when no API key is available
    from the ``GROQ_API_KEY`` environment variable or the password field.
    """
    with st.sidebar:
        st.markdown(
            "<h2 style='text-align: center;'>🧠 Knowledge Assistant</h2>",
            unsafe_allow_html=True,
        )

        st.subheader("🔑 API Key")
        # The input field is only shown when the environment has no key.
        groq_api_key = os.environ.get("GROQ_API_KEY") or st.text_input(
            "Groq API Key", type="password"
        )

        if not groq_api_key:
            st.warning("Please provide the Groq API key to proceed.")
            st.stop()

        st.subheader("💬 Chat History")

        if st.button("➕ New Chat", key="new_chat"):
            create_new_chat()

        new_name = st.text_input("Chat Name", value=st.session_state.session_name)
        if new_name != st.session_state.session_name:
            st.session_state.session_name = new_name
            save_chat_session()

        st.markdown("#### Previous Chats")
        # Most recently saved chats first.
        sorted_sessions = sorted(
            st.session_state.chat_sessions.items(),
            key=lambda x: x[1].get("timestamp", ""),
            reverse=True,
        )
        for session_id, session in sorted_sessions:
            col1, col2 = st.columns([0.8, 0.2])
            with col1:
                if st.button(session["name"], key=f"load_session_{session_id}"):
                    load_chat_session(session_id)
                    st.rerun()
            with col2:
                if st.button("🗑️", key=f"delete_{session_id}", help="Delete this chat"):
                    if session_id in st.session_state.chat_sessions:
                        del st.session_state.chat_sessions[session_id]
                    # Deleting the chat being viewed leaves nothing to show,
                    # so start a fresh one.
                    if session_id == st.session_state.current_session_id:
                        create_new_chat()
                    st.rerun()

    return groq_api_key


def _generate_answer(question: str, direct_chain, search_chain, llm) -> str:
    """Return the assistant's reply, using web search results when the LLM asks for them."""
    needs_search, terms = decide_search(question, llm)
    if needs_search and terms:
        web_results = perform_search(terms)
        return search_chain.run({"web_results": web_results, "question": question})
    return direct_chain.run({"question": question})


def _render_conversation(chat_container) -> None:
    """Render the welcome panel or the stored messages inside ``chat_container``."""
    with chat_container:
        if not st.session_state.messages:
            st.markdown("""
            <div style="text-align: center; padding: 50px 20px;">
            <h3>👋 Welcome to the General Knowledge Assistant!</h3>
            <p>Ask me anything about general knowledge, facts, or concepts.</p>
            <p>I can search the web when needed to provide you with up-to-date information.</p>
            </div>
            """, unsafe_allow_html=True)
            return

        for msg in st.session_state.messages:
            if isinstance(msg, dict) and "role" in msg and "content" in msg:
                with st.chat_message(msg["role"]):
                    st.write(msg["content"])
            else:
                st.error(f"Invalid message format: {msg}")


def main():
    """Run the Streamlit app: configure the page, draw the UI and answer questions.

    Each script run configures the page, injects the CSS, initialises the
    session state, renders the sidebar and builds the model chains. When
    the user submits a question it is appended to the conversation, a
    typing indicator is shown while the reply is generated, the reply (or
    an error message) is stored, and the script reruns to display it.
    """
    st.set_page_config(
        page_title="General Knowledge Assistant",
        page_icon="🧠",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    local_css()
    init_session_state()

    groq_api_key = _render_sidebar()
    direct_chain, search_chain, llm = setup_models(groq_api_key)

    st.markdown("""
    <div class="custom-header">
    <h1>🧠 General Knowledge Assistant</h1>
    </div>
    """, unsafe_allow_html=True)

    # Created before the input box so that the messages rendered into it
    # further down appear above the input.
    chat_container = st.container()
    user_input = st.chat_input("Ask me anything...")

    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        save_chat_session()

        with chat_container:
            typing_placeholder = st.empty()
            typing_placeholder.markdown("""
            <div class="typing-indicator">
            <span></span><span></span><span></span>
            </div>
            """, unsafe_allow_html=True)

        try:
            reply = _generate_answer(user_input, direct_chain, search_chain, llm)
        except Exception as e:
            # UI boundary: report any failure from the model or the search
            # tool inside the conversation instead of crashing the app.
            reply = f"Sorry, I encountered an error: {str(e)}"

        st.session_state.messages.append({"role": "assistant", "content": reply})
        save_chat_session()

        typing_placeholder.empty()
        # Rerun so the conversation is redrawn with the new messages.
        st.rerun()

    _render_conversation(chat_container)
|
|
|
# Start the app only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|