import streamlit as st
import pdfplumber
import docx
import os
import re
import numpy as np
import google.generativeai as palm  # Gemini API client (the "palm" alias is legacy naming)
import logging
import uuid
import json
import firebase_admin
from firebase_admin import credentials, firestore
from dotenv import load_dotenv
import chromadb
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Load environment variables
load_dotenv()
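# Illustrative .env layout read by load_dotenv() above (placeholder values, not real
# credentials); FIREBASE_CRED must hold the service-account JSON as a single string:
#   GOOGLE_API_KEY=your-gemini-api-key
#   FIREBASE_CRED={"type": "service_account", "project_id": "your-project", ...}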
# Configuration class
class Config:
CHUNK_WORDS = 300
EMBEDDING_MODEL = "models/text-embedding-004"
TOP_N = 5
    SYSTEM_PROMPT = (
        "You are a helpful assistant. Answer the question using the context provided below. "
        "If the provided context is not sufficient, answer from your own knowledge."
    )
GENERATION_MODEL = "models/gemini-1.5-flash"
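# Note: CHUNK_WORDS caps the approximate words per chunk, TOP_N is the number of chunks
# retrieved per query, and text-embedding-004 returns 768-dimensional vectors (hence the
# 768-length zero-vector fallback used below when embedding fails).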
# Initialize Firebase
def init_firebase():
"""Initialize Firebase with proper credential handling"""
if not firebase_admin._apps:
try:
firebase_cred = os.getenv("FIREBASE_CRED")
if not firebase_cred:
logger.error("Firebase credentials not found in environment variables")
st.error("Firebase configuration is missing. Please check your .env file.")
st.stop()
cred_dict = json.loads(firebase_cred)
cred = credentials.Certificate(cred_dict)
firebase_admin.initialize_app(cred)
logger.info("Firebase initialized successfully")
except json.JSONDecodeError:
logger.error("Invalid Firebase credentials format")
st.error("Firebase credentials are invalid. Please check your .env file.")
st.stop()
except Exception as e:
logger.error("Firebase initialization failed", exc_info=True)
st.error("Failed to initialize Firebase. Please contact support.")
st.stop()
# Initialize ChromaDB
def init_chroma():
"""Initialize ChromaDB with proper persistence handling"""
try:
persist_directory = "chroma_db"
os.makedirs(persist_directory, exist_ok=True)
client = chromadb.PersistentClient(path=persist_directory)
collection = client.get_or_create_collection(
name="document_embeddings",
metadata={"hnsw:space": "cosine"}
)
logger.info("ChromaDB initialized successfully")
return client, collection
except Exception as e:
logger.error("ChromaDB initialization failed", exc_info=True)
st.error("Failed to initialize ChromaDB. Please check your configuration.")
st.stop()
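# Note: PersistentClient keeps the collection on disk under chroma_db/, so stored embeddings
# survive app restarts. With "hnsw:space" set to "cosine", query() reports cosine distances
# (lower means more similar) rather than similarity scores.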
# Initialize services
init_firebase()
fs_client = firestore.client()
chroma_client, embedding_collection = init_chroma()
# Configure the Gemini API (google.generativeai)
API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
st.error("Google API key is not configured.")
st.stop()
palm.configure(api_key=API_KEY)
# Utility functions
@st.cache_data(show_spinner=True)
def generate_embedding_cached(text: str, task_type: str = "retrieval_document") -> list:
    """Generate an embedding with caching; returns a 768-length zero vector on failure"""
    logger.info(f"Generating embedding for text: {text[:50]}...")
    try:
        response = palm.embed_content(
            model=Config.EMBEDDING_MODEL,
            content=text,
            task_type=task_type  # "retrieval_document" for chunks, "retrieval_query" for queries
        )
if "embedding" not in response or not response["embedding"]:
logger.error("No embedding returned from API")
return [0.0] * 768
embedding = np.array(response["embedding"])
if embedding.ndim == 2:
embedding = embedding.flatten()
return embedding.tolist()
except Exception as e:
logger.error(f"Embedding generation failed: {e}")
return [0.0] * 768
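# Illustrative usage (assuming the embedding API call succeeds):
#   vec = generate_embedding_cached("hello world")
#   len(vec)   # -> 768
#   any(vec)   # -> True; a failed call returns [0.0] * 768, so any(vec) is False,
#              #    which is how process_document() filters out bad embeddings below.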
def extract_text_from_file(uploaded_file) -> str:
"""Extract text from various file formats"""
file_name = uploaded_file.name.lower()
if file_name.endswith(".txt"):
return uploaded_file.read().decode("utf-8")
elif file_name.endswith(".pdf"):
with pdfplumber.open(uploaded_file) as pdf:
return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
elif file_name.endswith(".docx"):
doc = docx.Document(uploaded_file)
return "\n".join([para.text for para in doc.paragraphs])
else:
raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
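# Note: pdfplumber only extracts text embedded in the PDF; pages of scanned/image-only PDFs
# return no text and are skipped, so such files can end up with empty extracted text.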
def chunk_text(text: str) -> list[str]:
"""Split text into manageable chunks"""
max_words = Config.CHUNK_WORDS
paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
chunks = []
current_chunk = ""
current_word_count = 0
for paragraph in paragraphs:
para_word_count = len(paragraph.split())
if para_word_count > max_words:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = ""
current_word_count = 0
sentences = re.split(r'(?<=[.!?])\s+', paragraph)
temp_chunk = ""
temp_word_count = 0
for sentence in sentences:
sentence_word_count = len(sentence.split())
if temp_word_count + sentence_word_count > max_words:
if temp_chunk:
chunks.append(temp_chunk.strip())
temp_chunk = sentence + " "
temp_word_count = sentence_word_count
else:
temp_chunk += sentence + " "
temp_word_count += sentence_word_count
if temp_chunk:
chunks.append(temp_chunk.strip())
else:
if current_word_count + para_word_count > max_words:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = paragraph + "\n\n"
current_word_count = para_word_count
else:
current_chunk += paragraph + "\n\n"
current_word_count += para_word_count
if current_chunk:
chunks.append(current_chunk.strip())
return chunks
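# Illustrative example: with CHUNK_WORDS = 300, a single ~750-word paragraph such as
#   chunk_text("one two three. " * 250)
# is split on sentence boundaries into chunks of roughly 300, 300, and 150 words, while
# shorter paragraphs are packed together until the 300-word limit is reached.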
def process_document(uploaded_file) -> None:
"""Process document and store in ChromaDB"""
try:
# Clear existing session state
keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
for key in keys_to_clear:
st.session_state.pop(key, None)
# Extract and validate text
file_text = extract_text_from_file(uploaded_file)
if not file_text.strip():
st.error("The uploaded file contains no valid text.")
return
# Process text into chunks
chunks = chunk_text(file_text)
if not chunks:
st.error("Failed to split text into chunks.")
return
        # Generate embeddings, keeping chunks and embeddings aligned
        valid_chunks = []
        embeddings = []
        chunk_ids = []
        progress_bar = st.progress(0)  # Initialize progress bar
        for i, chunk in enumerate(chunks):
            embedding = generate_embedding_cached(chunk)
            if not any(embedding):  # Skip the zero-vector fallback returned on failure
                continue
            valid_chunks.append(chunk)
            embeddings.append(embedding)
            chunk_ids.append(str(uuid.uuid4()))
            progress_bar.progress((i + 1) / len(chunks))  # Update progress bar
        if not embeddings:
            st.error("Failed to generate valid embeddings for the document.")
            return
        # Ensure the ChromaDB collection is available
        if embedding_collection is None:
            st.error("ChromaDB collection is not initialized.")
            return
        # Remove chunks stored earlier in this session so a new upload replaces them
        previous_ids = st.session_state.get("chunk_ids")
        if previous_ids:
            embedding_collection.delete(ids=previous_ids)
        # Save to ChromaDB
        embedding_collection.add(
            ids=chunk_ids,
            documents=valid_chunks,
            embeddings=embeddings,
            metadatas=[{"chunk_index": idx} for idx in range(len(valid_chunks))]
        )
        # Update session state
        st.session_state.update({
            "document_text": file_text,
            "document_chunks": valid_chunks,
            "document_embeddings": embeddings,
            "chunk_ids": chunk_ids
        })
        st.session_state["processed_file_name"] = uploaded_file.name
        st.success("Document processing complete! You can now start chatting.")
except Exception as e:
logger.error(f"Document processing failed: {e}")
st.error(f"An error occurred while processing the document: {e}")
def search_query(query: str) -> list[tuple[str, float]]:
    """Retrieve the most relevant document chunks for a query from ChromaDB"""
    try:
        query_embedding = generate_embedding_cached(query, task_type="retrieval_query")
        results = embedding_collection.query(
            query_embeddings=[query_embedding],
            n_results=Config.TOP_N
        )
        # query() returns one inner list per query embedding, so take index [0] for the single query
        documents = results["documents"][0]
        distances = results["distances"][0]
        # With cosine space these are distances (lower = more similar), not similarity scores
        return list(zip(documents, distances))
    except Exception as e:
        logger.error(f"Search query failed: {e}")
        return []
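# Illustrative shape of a ChromaDB query() result for a single query embedding (each field
# holds one inner list per query, which is why [0] is indexed above):
#   {
#       "ids":       [["id-3", "id-7", ...]],
#       "documents": [["chunk text ...", ...]],
#       "distances": [[0.12, 0.18, ...]],
#       "metadatas": [[{"chunk_index": 3}, ...]],
#   }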
def generate_answer(user_query: str, context: str) -> str:
"""Generate answer using Palm API"""
prompt = (
f"System: {Config.SYSTEM_PROMPT}\n\n"
f"Context:\n{context}\n\n"
f"User: {user_query}\nAssistant:"
)
try:
model = palm.GenerativeModel(Config.GENERATION_MODEL)
response = model.generate_content(prompt)
        return response.text if hasattr(response, "text") else str(response)
except Exception as e:
logger.error(f"Answer generation failed: {e}")
return "I'm sorry, I encountered an error generating a response."
# Firebase functions
def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
"""Save conversation to Firestore"""
conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations")
data = {
"user_question": user_question,
"assistant_answer": assistant_answer,
"feedback": feedback,
"timestamp": firestore.SERVER_TIMESTAMP
}
doc_ref = conv_ref.add(data)
return doc_ref[1].id
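# Resulting Firestore layout (one document per question/answer turn):
#   sessions/{session_id}/conversations/{auto_id}
#       user_question, assistant_answer, feedback, timestamp
# CollectionReference.add() returns an (update_time, DocumentReference) tuple, which is why
# doc_ref[1].id is used above.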
def update_feedback_in_firestore(session_id, conversation_id, feedback):
"""Update feedback in Firestore"""
conv_doc = fs_client.collection("sessions").document(session_id).collection("conversations").document(conversation_id)
conv_doc.update({"feedback": feedback})
def handle_feedback(feedback_val):
"""Handle user feedback"""
update_feedback_in_firestore(
st.session_state.session_id,
st.session_state.latest_conversation_id,
feedback_val
)
st.session_state.conversations[-1]["feedback"] = feedback_val
# Chat interface
def chat_app():
"""Main chat interface"""
if "conversations" not in st.session_state:
st.session_state.conversations = []
if "session_id" not in st.session_state:
st.session_state.session_id = str(uuid.uuid4())
# Display conversation history
for conv in st.session_state.conversations:
with st.chat_message("user"):
st.write(conv["user_question"])
with st.chat_message("assistant"):
st.write(conv["assistant_answer"])
if conv.get("feedback"):
st.markdown(f"**Feedback:** {conv['feedback']}")
# Handle new user input
user_input = st.chat_input("Type your message here")
if user_input:
with st.chat_message("user"):
st.write(user_input)
results = search_query(user_input)
context = "\n\n".join([chunk for chunk, score in results]) if results else ""
answer = generate_answer(user_input, context)
with st.chat_message("assistant"):
st.write(answer)
conversation_id = save_conversation_to_firestore(
st.session_state.session_id,
user_question=user_input,
assistant_answer=answer
)
st.session_state.latest_conversation_id = conversation_id
st.session_state.conversations.append({
"user_question": user_input,
"assistant_answer": answer,
})
        # Add feedback buttons (narrow columns keep the thumbs compact)
        if "feedback" not in st.session_state.conversations[-1]:
            cols = st.columns(10)
            cols[0].button("👍", key=f"feedback_like_{len(st.session_state.conversations)}",
                           on_click=handle_feedback, args=("positive",))
            cols[1].button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}",
                           on_click=handle_feedback, args=("negative",))
def main():
"""Main application"""
st.title("Chat with your files")
# Sidebar for file upload
st.sidebar.header("Upload Document")
uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])
    # Process the upload whenever a new file (identified by name) is provided
    if uploaded_file and st.session_state.get("processed_file_name") != uploaded_file.name:
        process_document(uploaded_file)
if "document_text" in st.session_state:
chat_app()
else:
st.info("Please upload and process a document from the sidebar to start chatting.")
# Footer
st.markdown(
"""
<div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
Made by Danny.<br>
Your questions, our response as well as your feedback will be saved for evaluation purposes.
</div>
""",
unsafe_allow_html=True
)
if __name__ == "__main__":
main()