import streamlit as st |
import pdfplumber |
import docx |
import os |
import re |
import numpy as np |
import google.generativeai as palm |
from sklearn.metrics.pairwise import cosine_similarity |
import logging |
import time |
import uuid |
import json |
import firebase_admin |
from firebase_admin import credentials, firestore |
from dotenv import load_dotenv |
import chromadb |
# Send INFO-and-above records to a stream handler using a timestamped format.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# Load variables from a .env file so the os.getenv() calls below can see them.
load_dotenv()
class Config:
    """Static settings for embedding, chunking, retrieval and answer generation."""

    # Embedding model passed to ``embed_content``.
    EMBEDDING_MODEL = "models/text-embedding-004"

    # Number of nearest chunks requested from the vector store per question.
    TOP_N = 5

    # Maximum number of words per chunk; read by chunk_text().
    # NOTE(review): this attribute was referenced but never defined in the
    # visible file; 300 is a placeholder value -- confirm the intended size.
    CHUNK_WORDS = 300

    # System instruction prepended to every generation prompt.  The assignment
    # target was missing, which left two orphan string literals and an
    # unmatched closing parenthesis inside the class body.
    SYSTEM_PROMPT = (
        "You are a helpful assistant. Answer the question using the provided context below. "
        "Answer based on your knowledge if the context given is not enough."
    )

    # Generative model used to produce the final answer.
    GENERATION_MODEL = "models/gemini-1.5-flash"
def init_firebase():
    """Initialize the Firebase Admin app from the FIREBASE_CRED environment variable.

    Does nothing when an app is already registered.  When the variable is
    missing, is not valid JSON, or initialization fails for any other reason,
    the problem is logged, an error is shown in the Streamlit UI and
    ``st.stop()`` is called.
    """
    # A non-empty app registry means a previous run already initialized Firebase.
    if firebase_admin._apps:
        return

    try:
        raw_credentials = os.getenv("FIREBASE_CRED")
        if not raw_credentials:
            logger.error("Firebase credentials not found in environment variables")
            st.error("Firebase configuration is missing. Please check your .env file.")
            st.stop()
        # The variable carries the credential as a JSON document; decode it
        # into a dict before handing it to the Admin SDK.
        certificate = credentials.Certificate(json.loads(raw_credentials))
        firebase_admin.initialize_app(certificate)
        logger.info("Firebase initialized successfully")
    except json.JSONDecodeError:
        logger.error("Invalid Firebase credentials format")
        st.error("Firebase credentials are invalid. Please check your .env file.")
        st.stop()
    except Exception:
        logger.error("Firebase initialization failed", exc_info=True)
        st.error("Failed to initialize Firebase. Please contact support.")
        st.stop()
def init_chroma():
    """Open the on-disk ChromaDB store and return ``(client, collection)``.

    The store lives in the ``chroma_db`` directory, which is created when
    absent.  The ``document_embeddings`` collection is fetched or created with
    cosine as its HNSW space.  On any failure the error is logged, shown in the
    Streamlit UI, and ``st.stop()`` is called.
    """
    storage_dir = "chroma_db"
    try:
        os.makedirs(storage_dir, exist_ok=True)
        chroma = chromadb.PersistentClient(path=storage_dir)
        documents = chroma.get_or_create_collection(
            metadata={"hnsw:space": "cosine"},
            name="document_embeddings",
        )
        logger.info("ChromaDB initialized successfully")
        return chroma, documents
    except Exception:
        logger.error("ChromaDB initialization failed", exc_info=True)
        st.error("Failed to initialize ChromaDB. Please check your configuration.")
        st.stop()
# Bring up the backing services for this script run.  Firebase must be
# initialized before a Firestore client can be requested.
init_firebase()
fs_client = firestore.client()

# Vector store client and the collection that holds the chunk embeddings.
chroma_client, embedding_collection = init_chroma()

# The generative AI SDK needs an API key; report and stop when none is set.
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    st.error("Google API key is not configured.")
    st.stop()

palm.configure(api_key=API_KEY)
# Length of the all-zero placeholder returned when no embedding is available
# (same size the original fallback used).
_FALLBACK_EMBEDDING_DIM = 768


@st.cache_data(show_spinner=True)
def _embed_text_cached(text: str) -> list:
    """Request an embedding for *text* and return it as a flat list of floats.

    Raises:
        ValueError: if the API response carries no embedding.

    Failures are raised rather than turned into a placeholder here, so that
    ``st.cache_data`` only ever memoizes successful results.
    """
    logger.info(f"Generating embedding for text: {text[:50]}...")
    response = palm.embed_content(
        model=Config.EMBEDDING_MODEL,
        content=text,
        task_type="retrieval_document",
    )
    if "embedding" not in response or not response["embedding"]:
        raise ValueError("No embedding returned from API")
    embedding = np.array(response["embedding"])
    if embedding.ndim == 2:
        # A nested (batched) result is collapsed into a single vector.
        embedding = embedding.flatten()
    return embedding.tolist()


def generate_embedding_cached(text: str) -> list:
    """Return the embedding for *text*, caching successful API results.

    Args:
        text: The text to embed.

    Returns:
        The embedding as a list of floats.  When the API call fails or returns
        nothing, an all-zero list of length 768 is returned instead; callers
        detect that placeholder with ``any(embedding)``.

    The placeholder is produced outside the cached helper on purpose:
    previously it was returned from inside the ``st.cache_data`` function, so a
    single transient API failure was cached and the same text could never be
    embedded again until the cache was cleared.
    """
    try:
        return _embed_text_cached(text)
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return [0.0] * _FALLBACK_EMBEDDING_DIM


# Keep ``generate_embedding_cached.clear()`` working for any caller that
# relied on the cache-management attribute of the decorated function.
generate_embedding_cached.clear = _embed_text_cached.clear
def extract_text_from_file(uploaded_file) -> str:
    """Return the plain text contained in an uploaded .txt, .pdf or .docx file.

    Args:
        uploaded_file: A file-like object with a ``name`` attribute; the
            extension (case-insensitive) selects the parser.

    Returns:
        The extracted text.  PDF pages that yield no text are skipped and the
        remaining pages are joined with newlines; .docx paragraphs are joined
        with newlines.

    Raises:
        ValueError: if the extension is not .txt, .pdf or .docx.
        UnicodeDecodeError: if a .txt file is not valid UTF-8.
    """
    file_name = uploaded_file.name.lower()

    if file_name.endswith(".txt"):
        # "utf-8-sig" decodes exactly like "utf-8" but drops a leading byte
        # order mark, which would otherwise end up inside the first chunk.
        return uploaded_file.read().decode("utf-8-sig")

    if file_name.endswith(".pdf"):
        with pdfplumber.open(uploaded_file) as pdf:
            # Extract each page once; the previous version ran the extraction
            # twice per page (once to filter, once to collect).
            page_texts = (page.extract_text() for page in pdf.pages)
            return "\n".join(page_text for page_text in page_texts if page_text)

    if file_name.endswith(".docx"):
        document = docx.Document(uploaded_file)
        return "\n".join(paragraph.text for paragraph in document.paragraphs)

    raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
# Sentence boundary: whitespace that follows ".", "!" or "?".
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')

# Used only when Config does not define CHUNK_WORDS.
# NOTE(review): placeholder value -- confirm the intended chunk size.
_DEFAULT_CHUNK_WORDS = 300


def _split_long_paragraph(paragraph: str, max_words: int) -> list[str]:
    """Split one oversized paragraph into pieces of at most *max_words* words."""
    pieces: list[str] = []
    buffered: list[str] = []
    buffered_words = 0

    for sentence in _SENTENCE_BOUNDARY.split(paragraph):
        words = sentence.split()

        if len(words) > max_words:
            # A single sentence over the limit cannot be kept whole: emit what
            # is buffered, then cut the sentence on word boundaries.
            if buffered:
                pieces.append(" ".join(buffered))
                buffered, buffered_words = [], 0
            pieces.extend(
                " ".join(words[start:start + max_words])
                for start in range(0, len(words), max_words)
            )
            continue

        if buffered and buffered_words + len(words) > max_words:
            pieces.append(" ".join(buffered))
            buffered, buffered_words = [], 0
        buffered.append(sentence)
        buffered_words += len(words)

    if buffered:
        pieces.append(" ".join(buffered))
    return pieces


def chunk_text(text: str, max_words: "int | None" = None) -> list[str]:
    """Split *text* into chunks of at most *max_words* words.

    Paragraphs (separated by blank lines) are packed together, joined by a
    blank line, until adding the next one would exceed the limit.  A paragraph
    that is itself over the limit is split on sentence boundaries, and a
    sentence that is still over the limit is cut on word boundaries so that no
    chunk is ever larger than *max_words*.

    Args:
        text: The text to split.
        max_words: Maximum words per chunk.  Defaults to ``Config.CHUNK_WORDS``
            and falls back to 300 when that attribute is not defined.

    Returns:
        The chunks in document order; an empty list for blank input.

    Raises:
        ValueError: if *max_words* is smaller than 1.
    """
    if max_words is None:
        max_words = getattr(Config, "CHUNK_WORDS", _DEFAULT_CHUNK_WORDS)
    if max_words < 1:
        raise ValueError("max_words must be a positive integer")

    chunks: list[str] = []
    pending: list[str] = []  # paragraphs collected for the chunk being built
    pending_words = 0

    for raw_paragraph in text.split("\n\n"):
        paragraph = raw_paragraph.strip()
        if not paragraph:
            continue
        word_count = len(paragraph.split())

        # Close the current chunk when this paragraph would not fit into it;
        # this also covers paragraphs that are over the limit on their own.
        if pending_words + word_count > max_words:
            if pending:
                chunks.append("\n\n".join(pending))
            pending, pending_words = [], 0

        if word_count > max_words:
            chunks.extend(_split_long_paragraph(paragraph, max_words))
        else:
            pending.append(paragraph)
            pending_words += word_count

    if pending:
        chunks.append("\n\n".join(pending))
    return chunks
def process_document(uploaded_file) -> None:
    """Extract, chunk and embed an uploaded file, then store the result.

    Embedded chunks are added to the ChromaDB collection and mirrored in
    ``st.session_state`` under ``document_text``, ``document_chunks``,
    ``document_embeddings`` and ``chunk_ids``.  Every failure is reported
    through ``st.error`` instead of being raised.

    Args:
        uploaded_file: The file object passed on to extract_text_from_file().
    """
    try:
        # Drop the data of any previously processed document.
        for key in ("document_text", "document_chunks", "document_embeddings", "chunk_ids"):
            st.session_state.pop(key, None)

        # Check the store first so no embeddings are requested for nothing.
        if embedding_collection is None:
            st.error("ChromaDB collection is not initialized.")
            return

        file_text = extract_text_from_file(uploaded_file)
        if not file_text.strip():
            st.error("The uploaded file contains no valid text.")
            return

        chunks = chunk_text(file_text)
        if not chunks:
            st.error("Failed to split text into chunks.")
            return

        # The three lists are kept in lockstep: entry i of each describes the
        # same chunk.  The previous version stored ``chunks[:len(embeddings)]``,
        # which paired the wrong text with the vectors as soon as any chunk
        # was skipped.
        kept_chunks = []
        embeddings = []
        chunk_ids = []
        progress_bar = st.progress(0)
        total = len(chunks)
        for position, chunk in enumerate(chunks, start=1):
            embedding = generate_embedding_cached(chunk)
            # Advance the bar for skipped chunks too, so it always reaches 100%.
            progress_bar.progress(position / total)
            if not any(embedding):
                # An all-zero vector is the embedding helper's failure marker.
                logger.warning("Skipping chunk %d: no embedding was generated", position - 1)
                continue
            kept_chunks.append(chunk)
            embeddings.append(embedding)
            chunk_ids.append(str(uuid.uuid4()))

        if not embeddings:
            st.error("Failed to generate valid embeddings for the document.")
            return

        embedding_collection.add(
            ids=chunk_ids,
            documents=kept_chunks,
            embeddings=embeddings,
            # chunk_index is the position inside session_state["document_chunks"].
            metadatas=[{"chunk_index": idx} for idx in range(len(kept_chunks))],
        )
        st.session_state.update({
            "document_text": file_text,
            "document_chunks": kept_chunks,
            "document_embeddings": embeddings,
            "chunk_ids": chunk_ids,
        })
        if not st.session_state.get("doc_processed", False):
            st.success("Document processing complete! You can now start chatting.")
            st.session_state.doc_processed = True
    except Exception as e:
        # logger.exception keeps the traceback, which the plain message lost.
        logger.exception("Document processing failed: %s", e)
        st.error(f"An error occurred while processing the document: {e}")
def search_query(query: str) -> list[tuple[str, float]]:
    """Return the stored chunks closest to *query*.

    Args:
        query: The user's question.

    Returns:
        ``(chunk_text, distance)`` pairs in the order the vector store returned
        them.  The second element is the raw distance reported by ChromaDB
        (smaller means closer).  An empty list is returned when the query
        cannot be embedded or the lookup fails.
    """
    try:
        query_embedding = generate_embedding_cached(query)
        if not any(query_embedding):
            # An all-zero vector is the embedding helper's failure marker and
            # carries no direction to compare against.
            logger.error("Search query failed: no embedding for the query")
            return []

        results = embedding_collection.query(
            query_embeddings=[query_embedding],
            n_results=Config.TOP_N,
        )

        # Every field of the result holds one inner list per query embedding.
        # Exactly one query was sent, so index 0 selects its hits.  Iterating
        # the outer list (as before) handed a list to ``metadata[...]``, raised
        # TypeError, and made every search return [].
        hit_ids = results["ids"][0]
        hit_documents = results["documents"][0]
        hit_distances = results["distances"][0]

        # The collection is persistent and may hold chunks from other uploads;
        # keep only hits that belong to the document of this session.  The
        # filter runs after the query, so fewer than TOP_N hits may remain.
        session_ids = set(st.session_state.get("chunk_ids", []))
        return [
            (document, distance)
            for hit_id, document, distance in zip(hit_ids, hit_documents, hit_distances)
            if not session_ids or hit_id in session_ids
        ]
    except Exception as e:
        logger.error(f"Search query failed: {e}")
        return []
def generate_answer(user_query: str, context: str) -> str:
    """Ask the configured generative model to answer *user_query*.

    The prompt is built from the system prompt, the retrieved *context* and the
    question.  When the model call fails, the error is logged and a fixed
    apology message is returned instead.
    """
    prompt = "".join([
        f"System: {Config.SYSTEM_PROMPT}\n\n",
        f"Context:\n{context}\n\n",
        f"User: {user_query}\nAssistant:",
    ])

    try:
        response = palm.GenerativeModel(Config.GENERATION_MODEL).generate_content(prompt)
        # Use the text attribute when the response has one; otherwise hand
        # back the response object unchanged.
        if hasattr(response, "text"):
            return response.text
        return response
    except Exception as e:
        logger.error(f"Answer generation failed: {e}")
        return "I'm sorry, I encountered an error generating a response."
def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
    """Store one question/answer exchange and return the new document's id.

    The record is added to ``sessions/<session_id>/conversations`` together
    with the optional feedback value and a server-side timestamp.
    """
    conversations = (
        fs_client.collection("sessions")
        .document(session_id)
        .collection("conversations")
    )
    record = {
        "user_question": user_question,
        "assistant_answer": assistant_answer,
        "feedback": feedback,
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    # The document reference is the second element of what add() returns.
    write_result = conversations.add(record)
    return write_result[1].id
def update_feedback_in_firestore(session_id, conversation_id, feedback):
    """Set the ``feedback`` field of one stored conversation document."""
    session_doc = fs_client.collection("sessions").document(session_id)
    conversation_doc = session_doc.collection("conversations").document(conversation_id)
    conversation_doc.update({"feedback": feedback})
def handle_feedback(feedback_val):
    """Record *feedback_val* for the latest conversation.

    The value is written to Firestore first and then copied onto the last
    entry of the in-session conversation list.
    """
    state = st.session_state
    update_feedback_in_firestore(
        state.session_id,
        state.latest_conversation_id,
        feedback_val,
    )
    latest_conversation = state.conversations[-1]
    latest_conversation["feedback"] = feedback_val
def chat_app():
    """Render the chat history and handle one new user message.

    For a new message the function retrieves context with search_query(),
    generates an answer, stores the exchange in Firestore, appends it to the
    session history and shows like/dislike buttons for the fresh answer.
    """
    state = st.session_state
    if "conversations" not in state:
        state.conversations = []
    if "session_id" not in state:
        state.session_id = str(uuid.uuid4())

    # Replay earlier turns; Streamlit re-runs the script on every interaction.
    for conv in state.conversations:
        with st.chat_message("user"):
            st.write(conv["user_question"])
        with st.chat_message("assistant"):
            st.write(conv["assistant_answer"])
            if conv.get("feedback"):
                st.markdown(f"**Feedback:** {conv['feedback']}")

    user_input = st.chat_input("Type your message here")
    if not user_input:
        return

    with st.chat_message("user"):
        st.write(user_input)

    results = search_query(user_input)
    context = "\n\n".join(chunk for chunk, _score in results)
    answer = generate_answer(user_input, context)
    with st.chat_message("assistant"):
        st.write(answer)

    # A failed write must not discard an answer the user has already seen:
    # keep the turn in the history and only withhold the feedback buttons,
    # which need a stored conversation id to update.
    try:
        conversation_id = save_conversation_to_firestore(
            state.session_id,
            user_question=user_input,
            assistant_answer=answer,
        )
    except Exception:
        logger.exception("Failed to save conversation to Firestore")
        st.warning("This conversation could not be saved, so feedback is unavailable for it.")
        conversation_id = None

    state.conversations.append({
        "user_question": user_input,
        "assistant_answer": answer,
    })
    if conversation_id is None:
        return
    state.latest_conversation_id = conversation_id

    # Ten narrow columns keep the two buttons small; only the first two are used.
    like_col, dislike_col, *_ = st.columns(10)
    turn = len(state.conversations)
    # Both labels had been mis-encoded to the same "π" character, which made
    # the like and dislike buttons indistinguishable.
    like_col.button("👍", key=f"feedback_like_{turn}",
                    on_click=handle_feedback, args=("positive",))
    dislike_col.button("👎", key=f"feedback_dislike_{turn}",
                       on_click=handle_feedback, args=("negative",))
def main():
    """Render the page: upload sidebar, chat area (or a hint) and the footer."""
    st.title("Chat with your files")

    sidebar = st.sidebar
    sidebar.header("Upload Document")
    uploaded_file = sidebar.file_uploader(
        "Upload (.txt, .pdf, .docx)",
        type=["txt", "pdf", "docx"],
    )

    # Process an upload only while no document has been processed yet.
    already_processed = st.session_state.get("doc_processed", False)
    if uploaded_file and not already_processed:
        process_document(uploaded_file)

    # The chat is available once a document's text is held in the session.
    if "document_text" not in st.session_state:
        st.info("Please upload and process a document from the sidebar to start chatting.")
    else:
        chat_app()

    # Fixed-position footer with the attribution and data-retention notice.
    footer_html = """
        <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
        Made by Danny.<br>
        Your questions, our response as well as your feedback will be saved for evaluation purposes.
        </div>
        """
    st.markdown(footer_html, unsafe_allow_html=True)


if __name__ == "__main__":
    main()