import streamlit as st import pdfplumber import docx import os import re import numpy as np import google.generativeai as palm from sklearn.metrics.pairwise import cosine_similarity import logging import time import uuid import json import firebase_admin from firebase_admin import credentials, firestore from dotenv import load_dotenv import chromadb # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] ) logger = logging.getLogger(__name__) # Load environment variables load_dotenv() # Configuration class class Config: CHUNK_WORDS = 300 EMBEDDING_MODEL = "models/text-embedding-004" TOP_N = 5 SYSTEM_PROMPT = ( "You are a helpful assistant. Answer the question using the provided context below. " "Answer based on your knowledge if the context given is not enough." ) GENERATION_MODEL = "models/gemini-1.5-flash" # Initialize Firebase def init_firebase(): """Initialize Firebase with proper credential handling""" if not firebase_admin._apps: try: firebase_cred = os.getenv("FIREBASE_CRED") if not firebase_cred: logger.error("Firebase credentials not found in environment variables") st.error("Firebase configuration is missing. Please check your .env file.") st.stop() cred_dict = json.loads(firebase_cred) cred = credentials.Certificate(cred_dict) firebase_admin.initialize_app(cred) logger.info("Firebase initialized successfully") except json.JSONDecodeError: logger.error("Invalid Firebase credentials format") st.error("Firebase credentials are invalid. Please check your .env file.") st.stop() except Exception as e: logger.error("Firebase initialization failed", exc_info=True) st.error("Failed to initialize Firebase. Please contact support.") st.stop() # Initialize ChromaDB def init_chroma(): """Initialize ChromaDB with proper persistence handling""" try: persist_directory = "chroma_db" os.makedirs(persist_directory, exist_ok=True) client = chromadb.PersistentClient(path=persist_directory) collection = client.get_or_create_collection( name="document_embeddings", metadata={"hnsw:space": "cosine"} ) logger.info("ChromaDB initialized successfully") return client, collection except Exception as e: logger.error("ChromaDB initialization failed", exc_info=True) st.error("Failed to initialize ChromaDB. Please check your configuration.") st.stop() # Initialize services init_firebase() fs_client = firestore.client() chroma_client, embedding_collection = init_chroma() # Configure Palm API API_KEY = os.getenv("GOOGLE_API_KEY") if not API_KEY: st.error("Google API key is not configured.") st.stop() palm.configure(api_key=API_KEY) # Utility functions @st.cache_data(show_spinner=True) def generate_embedding_cached(text: str) -> list: """Generate embeddings with caching""" logger.info(f"Generating embedding for text: {text[:50]}...") try: response = palm.embed_content( model=Config.EMBEDDING_MODEL, content=text, task_type="retrieval_document" ) if "embedding" not in response or not response["embedding"]: logger.error("No embedding returned from API") return [0.0] * 768 embedding = np.array(response["embedding"]) if embedding.ndim == 2: embedding = embedding.flatten() return embedding.tolist() except Exception as e: logger.error(f"Embedding generation failed: {e}") return [0.0] * 768 def extract_text_from_file(uploaded_file) -> str: """Extract text from various file formats""" file_name = uploaded_file.name.lower() if file_name.endswith(".txt"): return uploaded_file.read().decode("utf-8") elif file_name.endswith(".pdf"): with pdfplumber.open(uploaded_file) as pdf: return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()]) elif file_name.endswith(".docx"): doc = docx.Document(uploaded_file) return "\n".join([para.text for para in doc.paragraphs]) else: raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.") def chunk_text(text: str) -> list[str]: """Split text into manageable chunks""" max_words = Config.CHUNK_WORDS paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] chunks = [] current_chunk = "" current_word_count = 0 for paragraph in paragraphs: para_word_count = len(paragraph.split()) if para_word_count > max_words: if current_chunk: chunks.append(current_chunk.strip()) current_chunk = "" current_word_count = 0 sentences = re.split(r'(?<=[.!?])\s+', paragraph) temp_chunk = "" temp_word_count = 0 for sentence in sentences: sentence_word_count = len(sentence.split()) if temp_word_count + sentence_word_count > max_words: if temp_chunk: chunks.append(temp_chunk.strip()) temp_chunk = sentence + " " temp_word_count = sentence_word_count else: temp_chunk += sentence + " " temp_word_count += sentence_word_count if temp_chunk: chunks.append(temp_chunk.strip()) else: if current_word_count + para_word_count > max_words: if current_chunk: chunks.append(current_chunk.strip()) current_chunk = paragraph + "\n\n" current_word_count = para_word_count else: current_chunk += paragraph + "\n\n" current_word_count += para_word_count if current_chunk: chunks.append(current_chunk.strip()) return chunks def process_document(uploaded_file) -> None: """Process document and store in ChromaDB""" try: # Clear existing session state keys_to_clear = ["document_text", "document_chunks", "document_embeddings"] for key in keys_to_clear: st.session_state.pop(key, None) # Extract and validate text file_text = extract_text_from_file(uploaded_file) if not file_text.strip(): st.error("The uploaded file contains no valid text.") return # Process text into chunks chunks = chunk_text(file_text) if not chunks: st.error("Failed to split text into chunks.") return # Generate embeddings embeddings = [] chunk_ids = [] progress_bar = st.progress(0) # ✅ Correctly initialize progress bar for i, chunk in enumerate(chunks): chunk_id = str(uuid.uuid4()) embedding = generate_embedding_cached(chunk) if not any(embedding): # Ensure embedding is valid continue embeddings.append(embedding) chunk_ids.append(chunk_id) progress_bar.progress((i + 1) / len(chunks)) # ✅ Update progress bar if not embeddings: st.error("Failed to generate valid embeddings for the document.") return # Ensure `embedding_collection` is properly initialized if embedding_collection is None: st.error("ChromaDB collection is not initialized.") return # Save to ChromaDB embedding_collection.add( ids=chunk_ids, documents=chunks[:len(embeddings)], embeddings=embeddings, metadatas=[{"chunk_index": idx} for idx in range(len(embeddings))] ) # Update session state st.session_state.update({ "document_text": file_text, "document_chunks": chunks[:len(embeddings)], "document_embeddings": embeddings, "chunk_ids": chunk_ids }) if not st.session_state.get("doc_processed", False): st.success("Document processing complete! You can now start chatting.") st.session_state.doc_processed = True except Exception as e: logger.error(f"Document processing failed: {e}") st.error(f"An error occurred while processing the document: {e}") def search_query(query: str) -> list[tuple[str, float]]: """Search for relevant document chunks""" try: query_embedding = generate_embedding_cached(query) results = embedding_collection.query( query_embeddings=[query_embedding], n_results=Config.TOP_N ) results_data = [] for i, metadata in enumerate(results["metadatas"]): chunk_index = metadata["chunk_index"] similarity_score = results["distances"][i] results_data.append((st.session_state["document_chunks"][chunk_index], similarity_score)) return results_data except Exception as e: logger.error(f"Search query failed: {e}") return [] def generate_answer(user_query: str, context: str) -> str: """Generate answer using Palm API""" prompt = ( f"System: {Config.SYSTEM_PROMPT}\n\n" f"Context:\n{context}\n\n" f"User: {user_query}\nAssistant:" ) try: model = palm.GenerativeModel(Config.GENERATION_MODEL) response = model.generate_content(prompt) return response.text if hasattr(response, "text") else response except Exception as e: logger.error(f"Answer generation failed: {e}") return "I'm sorry, I encountered an error generating a response." # Firebase functions def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None): """Save conversation to Firestore""" conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations") data = { "user_question": user_question, "assistant_answer": assistant_answer, "feedback": feedback, "timestamp": firestore.SERVER_TIMESTAMP } doc_ref = conv_ref.add(data) return doc_ref[1].id def update_feedback_in_firestore(session_id, conversation_id, feedback): """Update feedback in Firestore""" conv_doc = fs_client.collection("sessions").document(session_id).collection("conversations").document(conversation_id) conv_doc.update({"feedback": feedback}) def handle_feedback(feedback_val): """Handle user feedback""" update_feedback_in_firestore( st.session_state.session_id, st.session_state.latest_conversation_id, feedback_val ) st.session_state.conversations[-1]["feedback"] = feedback_val # Chat interface def chat_app(): """Main chat interface""" if "conversations" not in st.session_state: st.session_state.conversations = [] if "session_id" not in st.session_state: st.session_state.session_id = str(uuid.uuid4()) # Display conversation history for conv in st.session_state.conversations: with st.chat_message("user"): st.write(conv["user_question"]) with st.chat_message("assistant"): st.write(conv["assistant_answer"]) if conv.get("feedback"): st.markdown(f"**Feedback:** {conv['feedback']}") # Handle new user input user_input = st.chat_input("Type your message here") if user_input: with st.chat_message("user"): st.write(user_input) results = search_query(user_input) context = "\n\n".join([chunk for chunk, score in results]) if results else "" answer = generate_answer(user_input, context) with st.chat_message("assistant"): st.write(answer) conversation_id = save_conversation_to_firestore( st.session_state.session_id, user_question=user_input, assistant_answer=answer ) st.session_state.latest_conversation_id = conversation_id st.session_state.conversations.append({ "user_question": user_input, "assistant_answer": answer, }) # Add feedback buttons if "feedback" not in st.session_state.conversations[-1]: col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = st.columns(10) col1.button("👍", key=f"feedback_like_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("positive",)) col2.button("👎", key=f"feedback_dislike_{len(st.session_state.conversations)}", on_click=handle_feedback, args=("negative",)) def main(): """Main application""" st.title("Chat with your files") # Sidebar for file upload st.sidebar.header("Upload Document") uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"]) if uploaded_file and not st.session_state.get("doc_processed", False): process_document(uploaded_file) if "document_text" in st.session_state: chat_app() else: st.info("Please upload and process a document from the sidebar to start chatting.") # Footer st.markdown( """