import streamlit as st
import pdfplumber
import docx
import os
import re
import numpy as np
import google.generativeai as palm
from sklearn.metrics.pairwise import cosine_similarity
import logging
import time
import uuid
import json
import firebase_admin
from firebase_admin import credentials, firestore
from dotenv import load_dotenv
import chromadb


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)


load_dotenv()
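
# Launch with `streamlit run <this file>`. A local .env file (loaded above) must
# provide the two variables read via os.getenv() below:
#   GOOGLE_API_KEY - API key for the Google Generative AI (Gemini) API
#   FIREBASE_CRED  - Firebase service-account credentials as a single JSON string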


class Config:
    CHUNK_WORDS = 300  # maximum words per text chunk
    EMBEDDING_MODEL = "models/text-embedding-004"
    TOP_N = 5  # number of chunks retrieved per query
    SYSTEM_PROMPT = (
        "You are a helpful assistant. Answer the question using the provided context below. "
        "Answer based on your knowledge if the context given is not enough."
    )
    GENERATION_MODEL = "models/gemini-1.5-flash"


def init_firebase():
    """Initialize Firebase with proper credential handling"""
    if not firebase_admin._apps:
        try:
            firebase_cred = os.getenv("FIREBASE_CRED")
            if not firebase_cred:
                logger.error("Firebase credentials not found in environment variables")
                st.error("Firebase configuration is missing. Please check your .env file.")
                st.stop()

            cred_dict = json.loads(firebase_cred)
            cred = credentials.Certificate(cred_dict)
            firebase_admin.initialize_app(cred)
            logger.info("Firebase initialized successfully")

        except json.JSONDecodeError:
            logger.error("Invalid Firebase credentials format")
            st.error("Firebase credentials are invalid. Please check your .env file.")
            st.stop()
        except Exception:
            logger.error("Firebase initialization failed", exc_info=True)
            st.error("Failed to initialize Firebase. Please contact support.")
            st.stop()


def init_chroma():
    """Initialize ChromaDB with proper persistence handling"""
    try:
        persist_directory = "chroma_db"
        os.makedirs(persist_directory, exist_ok=True)

        client = chromadb.PersistentClient(path=persist_directory)
        collection = client.get_or_create_collection(
            name="document_embeddings",
            metadata={"hnsw:space": "cosine"}
        )
        logger.info("ChromaDB initialized successfully")
        return client, collection
    except Exception:
        logger.error("ChromaDB initialization failed", exc_info=True)
        st.error("Failed to initialize ChromaDB. Please check your configuration.")
        st.stop()


init_firebase()
fs_client = firestore.client()
chroma_client, embedding_collection = init_chroma()


API_KEY = os.getenv("GOOGLE_API_KEY")
if not API_KEY:
    st.error("Google API key is not configured.")
    st.stop()
palm.configure(api_key=API_KEY)


@st.cache_data(show_spinner=True)
def generate_embedding_cached(text: str) -> list:
    """Generate embeddings with caching"""
    logger.info(f"Generating embedding for text: {text[:50]}...")
    try:
        response = palm.embed_content(
            model=Config.EMBEDDING_MODEL,
            content=text,
            task_type="retrieval_document"
        )
        if "embedding" not in response or not response["embedding"]:
            logger.error("No embedding returned from API")
            return [0.0] * 768  # zero vector matching text-embedding-004's 768 dimensions

        embedding = np.array(response["embedding"])
        if embedding.ndim == 2:
            embedding = embedding.flatten()
        return embedding.tolist()
    except Exception as e:
        logger.error(f"Embedding generation failed: {e}")
        return [0.0] * 768


def extract_text_from_file(uploaded_file) -> str:
    """Extract text from various file formats"""
    file_name = uploaded_file.name.lower()

    if file_name.endswith(".txt"):
        return uploaded_file.read().decode("utf-8")
    elif file_name.endswith(".pdf"):
        with pdfplumber.open(uploaded_file) as pdf:
            return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
    elif file_name.endswith(".docx"):
        doc = docx.Document(uploaded_file)
        return "\n".join([para.text for para in doc.paragraphs])
    else:
        raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")


def chunk_text(text: str) -> list[str]:
    """Split text into manageable chunks"""
    max_words = Config.CHUNK_WORDS
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks = []
    current_chunk = ""
    current_word_count = 0

    for paragraph in paragraphs:
        para_word_count = len(paragraph.split())

        if para_word_count > max_words:
            # Flush whatever has accumulated so far
            if current_chunk:
                chunks.append(current_chunk.strip())
                current_chunk = ""
                current_word_count = 0

            # Paragraph is too long on its own: split it at sentence boundaries
            sentences = re.split(r'(?<=[.!?])\s+', paragraph)
            temp_chunk = ""
            temp_word_count = 0

            for sentence in sentences:
                sentence_word_count = len(sentence.split())
                if temp_word_count + sentence_word_count > max_words:
                    if temp_chunk:
                        chunks.append(temp_chunk.strip())
                    temp_chunk = sentence + " "
                    temp_word_count = sentence_word_count
                else:
                    temp_chunk += sentence + " "
                    temp_word_count += sentence_word_count

            if temp_chunk:
                chunks.append(temp_chunk.strip())
        else:
            if current_word_count + para_word_count > max_words:
                if current_chunk:
                    chunks.append(current_chunk.strip())
                current_chunk = paragraph + "\n\n"
                current_word_count = para_word_count
            else:
                current_chunk += paragraph + "\n\n"
                current_word_count += para_word_count

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


def process_document(uploaded_file) -> None:
    """Process document and store in ChromaDB"""
    try:
        # Clear any previously processed document from the session
        keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
        for key in keys_to_clear:
            st.session_state.pop(key, None)

        file_text = extract_text_from_file(uploaded_file)
        if not file_text.strip():
            st.error("The uploaded file contains no valid text.")
            return

        chunks = chunk_text(file_text)
        if not chunks:
            st.error("Failed to split text into chunks.")
            return

        embeddings = []
        valid_chunks = []
        chunk_ids = []

        progress_bar = st.progress(0)

        for i, chunk in enumerate(chunks):
            embedding = generate_embedding_cached(chunk)

            # Skip chunks whose embedding failed (all-zero fallback vector),
            # keeping chunks, ids, and embeddings aligned
            if not any(embedding):
                continue

            embeddings.append(embedding)
            valid_chunks.append(chunk)
            chunk_ids.append(str(uuid.uuid4()))
            progress_bar.progress((i + 1) / len(chunks))

        if not embeddings:
            st.error("Failed to generate valid embeddings for the document.")
            return

        if embedding_collection is None:
            st.error("ChromaDB collection is not initialized.")
            return

        embedding_collection.add(
            ids=chunk_ids,
            documents=valid_chunks,
            embeddings=embeddings,
            metadatas=[{"chunk_index": idx} for idx in range(len(valid_chunks))]
        )

        st.session_state.update({
            "document_text": file_text,
            "document_chunks": valid_chunks,
            "document_embeddings": embeddings,
            "chunk_ids": chunk_ids
        })

        if not st.session_state.get("doc_processed", False):
            st.success("Document processing complete! You can now start chatting.")
            st.session_state.doc_processed = True

    except Exception as e:
        logger.error(f"Document processing failed: {e}")
        st.error(f"An error occurred while processing the document: {e}")


def search_query(query: str) -> list[tuple[str, float]]:
    """Search for relevant document chunks"""
    try:
        query_embedding = generate_embedding_cached(query)

        results = embedding_collection.query(
            query_embeddings=[query_embedding],
            n_results=Config.TOP_N
        )

        # ChromaDB returns one result list per query embedding, so index [0] first
        results_data = []
        for i, metadata in enumerate(results["metadatas"][0]):
            chunk_index = metadata["chunk_index"]
            distance = results["distances"][0][i]
            results_data.append((st.session_state["document_chunks"][chunk_index], distance))

        return results_data
    except Exception as e:
        logger.error(f"Search query failed: {e}")
        return []


def generate_answer(user_query: str, context: str) -> str:
    """Generate answer using the Gemini API"""
    prompt = (
        f"System: {Config.SYSTEM_PROMPT}\n\n"
        f"Context:\n{context}\n\n"
        f"User: {user_query}\nAssistant:"
    )
    try:
        model = palm.GenerativeModel(Config.GENERATION_MODEL)
        response = model.generate_content(prompt)
        return response.text if hasattr(response, "text") else str(response)
    except Exception as e:
        logger.error(f"Answer generation failed: {e}")
        return "I'm sorry, I encountered an error generating a response."


def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
    """Save conversation to Firestore"""
    conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations")
    data = {
        "user_question": user_question,
        "assistant_answer": assistant_answer,
        "feedback": feedback,
        "timestamp": firestore.SERVER_TIMESTAMP
    }
    # add() returns an (update_time, DocumentReference) tuple
    _, doc_ref = conv_ref.add(data)
    return doc_ref.id


def update_feedback_in_firestore(session_id, conversation_id, feedback):
    """Update feedback in Firestore"""
    conv_doc = fs_client.collection("sessions").document(session_id).collection("conversations").document(conversation_id)
    conv_doc.update({"feedback": feedback})


def handle_feedback(feedback_val):
    """Handle user feedback"""
    update_feedback_in_firestore(
        st.session_state.session_id,
        st.session_state.latest_conversation_id,
        feedback_val
    )
    st.session_state.conversations[-1]["feedback"] = feedback_val


def chat_app():
    """Main chat interface"""
    if "conversations" not in st.session_state:
        st.session_state.conversations = []
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())

    # Replay the conversation history
    for conv in st.session_state.conversations:
        with st.chat_message("user"):
            st.write(conv["user_question"])
        with st.chat_message("assistant"):
            st.write(conv["assistant_answer"])
            if conv.get("feedback"):
                st.markdown(f"**Feedback:** {conv['feedback']}")

    user_input = st.chat_input("Type your message here")
    if user_input:
        with st.chat_message("user"):
            st.write(user_input)

        results = search_query(user_input)
        context = "\n\n".join([chunk for chunk, score in results]) if results else ""
        answer = generate_answer(user_input, context)

        with st.chat_message("assistant"):
            st.write(answer)

        conversation_id = save_conversation_to_firestore(
            st.session_state.session_id,
            user_question=user_input,
            assistant_answer=answer
        )
        st.session_state.latest_conversation_id = conversation_id
        st.session_state.conversations.append({
            "user_question": user_input,
            "assistant_answer": answer,
        })

        if "feedback" not in st.session_state.conversations[-1]:
            col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = st.columns(10)
            col1.button("👍", key=f"feedback_like_{len(st.session_state.conversations)}",
                        on_click=handle_feedback, args=("positive",))
            col2.button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}",
                        on_click=handle_feedback, args=("negative",))


def main():
    """Main application"""
    st.title("Chat with your files")

    st.sidebar.header("Upload Document")
    uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])

    if uploaded_file and not st.session_state.get("doc_processed", False):
        process_document(uploaded_file)

    if "document_text" in st.session_state:
        chat_app()
    else:
        st.info("Please upload and process a document from the sidebar to start chatting.")

    st.markdown(
        """
        <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
            Made by Danny.<br>
            Your questions, our responses, and your feedback will be saved for evaluation purposes.
        </div>
        """,
        unsafe_allow_html=True
    )


if __name__ == "__main__":
    main()