jessica45 commited on
Commit
df2bc4a
Β·
verified Β·
1 Parent(s): 171e1a7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +392 -0
app.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pdfplumber
3
+ import docx
4
+ import os
5
+ import re
6
+ import numpy as np
7
+ import google.generativeai as palm
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ import logging
10
+ import time
11
+ import uuid
12
+ import json
13
+ import firebase_admin
14
+ from firebase_admin import credentials, firestore
15
+ from dotenv import load_dotenv
16
+ import chromadb
17
+
18
+ # Configure logging
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format='%(asctime)s - %(levelname)s - %(message)s',
22
+ handlers=[logging.StreamHandler()]
23
+ )
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Load environment variables
27
+ load_dotenv()
28
+
29
+ # Configuration class
30
+ class Config:
31
+ CHUNK_WORDS = 300
32
+ EMBEDDING_MODEL = "models/text-embedding-004"
33
+ TOP_N = 5
34
+ SYSTEM_PROMPT = (
35
+ "You are a helpful assistant. Answer the question using the provided context below. "
36
+ "Answer based on your knowledge if the context given is not enough."
37
+ )
38
+ GENERATION_MODEL = "models/gemini-1.5-flash"
39
+
40
+ # Initialize Firebase
41
+ def init_firebase():
42
+ """Initialize Firebase with proper credential handling"""
43
+ if not firebase_admin._apps:
44
+ try:
45
+ firebase_cred = os.getenv("FIREBASE_CRED")
46
+ if not firebase_cred:
47
+ logger.error("Firebase credentials not found in environment variables")
48
+ st.error("Firebase configuration is missing. Please check your .env file.")
49
+ st.stop()
50
+
51
+ cred_dict = json.loads(firebase_cred)
52
+ cred = credentials.Certificate(cred_dict)
53
+ firebase_admin.initialize_app(cred)
54
+ logger.info("Firebase initialized successfully")
55
+
56
+ except json.JSONDecodeError:
57
+ logger.error("Invalid Firebase credentials format")
58
+ st.error("Firebase credentials are invalid. Please check your .env file.")
59
+ st.stop()
60
+ except Exception as e:
61
+ logger.error("Firebase initialization failed", exc_info=True)
62
+ st.error("Failed to initialize Firebase. Please contact support.")
63
+ st.stop()
64
+
65
+ # Initialize ChromaDB
66
+ def init_chroma():
67
+ """Initialize ChromaDB with proper persistence handling"""
68
+ try:
69
+ persist_directory = "chroma_db"
70
+ os.makedirs(persist_directory, exist_ok=True)
71
+
72
+ client = chromadb.PersistentClient(path=persist_directory)
73
+ collection = client.get_or_create_collection(
74
+ name="document_embeddings",
75
+ metadata={"hnsw:space": "cosine"}
76
+ )
77
+ logger.info("ChromaDB initialized successfully")
78
+ return client, collection
79
+ except Exception as e:
80
+ logger.error("ChromaDB initialization failed", exc_info=True)
81
+ st.error("Failed to initialize ChromaDB. Please check your configuration.")
82
+ st.stop()
83
+
84
+ # Initialize services
85
+ init_firebase()
86
+ fs_client = firestore.client()
87
+ chroma_client, embedding_collection = init_chroma()
88
+
89
+ # Configure Palm API
90
+ API_KEY = os.getenv("GOOGLE_API_KEY")
91
+ if not API_KEY:
92
+ st.error("Google API key is not configured.")
93
+ st.stop()
94
+ palm.configure(api_key=API_KEY)
95
+
96
+ # Utility functions
97
+ @st.cache_data(show_spinner=True)
98
+ def generate_embedding_cached(text: str) -> list:
99
+ """Generate embeddings with caching"""
100
+ logger.info(f"Generating embedding for text: {text[:50]}...")
101
+ try:
102
+ response = palm.embed_content(
103
+ model=Config.EMBEDDING_MODEL,
104
+ content=text,
105
+ task_type="retrieval_document"
106
+ )
107
+ if "embedding" not in response or not response["embedding"]:
108
+ logger.error("No embedding returned from API")
109
+ return [0.0] * 768
110
+
111
+ embedding = np.array(response["embedding"])
112
+ if embedding.ndim == 2:
113
+ embedding = embedding.flatten()
114
+ return embedding.tolist()
115
+ except Exception as e:
116
+ logger.error(f"Embedding generation failed: {e}")
117
+ return [0.0] * 768
118
+
119
+ def extract_text_from_file(uploaded_file) -> str:
120
+ """Extract text from various file formats"""
121
+ file_name = uploaded_file.name.lower()
122
+
123
+ if file_name.endswith(".txt"):
124
+ return uploaded_file.read().decode("utf-8")
125
+ elif file_name.endswith(".pdf"):
126
+ with pdfplumber.open(uploaded_file) as pdf:
127
+ return "\n".join([page.extract_text() for page in pdf.pages if page.extract_text()])
128
+ elif file_name.endswith(".docx"):
129
+ doc = docx.Document(uploaded_file)
130
+ return "\n".join([para.text for para in doc.paragraphs])
131
+ else:
132
+ raise ValueError("Unsupported file type. Please upload a .txt, .pdf, or .docx file.")
133
+
134
+ def chunk_text(text: str) -> list[str]:
135
+ """Split text into manageable chunks"""
136
+ max_words = Config.CHUNK_WORDS
137
+ paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
138
+ chunks = []
139
+ current_chunk = ""
140
+ current_word_count = 0
141
+
142
+ for paragraph in paragraphs:
143
+ para_word_count = len(paragraph.split())
144
+
145
+ if para_word_count > max_words:
146
+ if current_chunk:
147
+ chunks.append(current_chunk.strip())
148
+ current_chunk = ""
149
+ current_word_count = 0
150
+
151
+ sentences = re.split(r'(?<=[.!?])\s+', paragraph)
152
+ temp_chunk = ""
153
+ temp_word_count = 0
154
+
155
+ for sentence in sentences:
156
+ sentence_word_count = len(sentence.split())
157
+ if temp_word_count + sentence_word_count > max_words:
158
+ if temp_chunk:
159
+ chunks.append(temp_chunk.strip())
160
+ temp_chunk = sentence + " "
161
+ temp_word_count = sentence_word_count
162
+ else:
163
+ temp_chunk += sentence + " "
164
+ temp_word_count += sentence_word_count
165
+
166
+ if temp_chunk:
167
+ chunks.append(temp_chunk.strip())
168
+ else:
169
+ if current_word_count + para_word_count > max_words:
170
+ if current_chunk:
171
+ chunks.append(current_chunk.strip())
172
+ current_chunk = paragraph + "\n\n"
173
+ current_word_count = para_word_count
174
+ else:
175
+ current_chunk += paragraph + "\n\n"
176
+ current_word_count += para_word_count
177
+
178
+ if current_chunk:
179
+ chunks.append(current_chunk.strip())
180
+ return chunks
181
+
182
+ def process_document(uploaded_file) -> None:
183
+ """Process document and store in ChromaDB"""
184
+ try:
185
+ # Clear existing session state
186
+ keys_to_clear = ["document_text", "document_chunks", "document_embeddings"]
187
+ for key in keys_to_clear:
188
+ st.session_state.pop(key, None)
189
+
190
+ # Extract and validate text
191
+ file_text = extract_text_from_file(uploaded_file)
192
+ if not file_text.strip():
193
+ st.error("The uploaded file contains no valid text.")
194
+ return
195
+
196
+ # Process text into chunks
197
+ chunks = chunk_text(file_text)
198
+ if not chunks:
199
+ st.error("Failed to split text into chunks.")
200
+ return
201
+
202
+ # Generate embeddings
203
+ embeddings = []
204
+ chunk_ids = []
205
+
206
+ progress_bar = st.progress(0) # βœ… Correctly initialize progress bar
207
+
208
+ for i, chunk in enumerate(chunks):
209
+ chunk_id = str(uuid.uuid4())
210
+ embedding = generate_embedding_cached(chunk)
211
+
212
+ if not any(embedding): # Ensure embedding is valid
213
+ continue
214
+
215
+ embeddings.append(embedding)
216
+ chunk_ids.append(chunk_id)
217
+ progress_bar.progress((i + 1) / len(chunks)) # βœ… Update progress bar
218
+
219
+ if not embeddings:
220
+ st.error("Failed to generate valid embeddings for the document.")
221
+ return
222
+
223
+ # Ensure `embedding_collection` is properly initialized
224
+ if embedding_collection is None:
225
+ st.error("ChromaDB collection is not initialized.")
226
+ return
227
+
228
+ # Save to ChromaDB
229
+ embedding_collection.add(
230
+ ids=chunk_ids,
231
+ documents=chunks[:len(embeddings)],
232
+ embeddings=embeddings,
233
+ metadatas=[{"chunk_index": idx} for idx in range(len(embeddings))]
234
+ )
235
+
236
+ # Update session state
237
+ st.session_state.update({
238
+ "document_text": file_text,
239
+ "document_chunks": chunks[:len(embeddings)],
240
+ "document_embeddings": embeddings,
241
+ "chunk_ids": chunk_ids
242
+ })
243
+
244
+ if not st.session_state.get("doc_processed", False):
245
+ st.success("Document processing complete! You can now start chatting.")
246
+ st.session_state.doc_processed = True
247
+
248
+ except Exception as e:
249
+ logger.error(f"Document processing failed: {e}")
250
+ st.error(f"An error occurred while processing the document: {e}")
251
+
252
+ def search_query(query: str) -> list[tuple[str, float]]:
253
+ """Search for relevant document chunks"""
254
+ try:
255
+ query_embedding = generate_embedding_cached(query)
256
+
257
+ results = embedding_collection.query(
258
+ query_embeddings=[query_embedding],
259
+ n_results=Config.TOP_N
260
+ )
261
+
262
+ results_data = []
263
+ for i, metadata in enumerate(results["metadatas"]):
264
+ chunk_index = metadata["chunk_index"]
265
+ similarity_score = results["distances"][i]
266
+ results_data.append((st.session_state["document_chunks"][chunk_index], similarity_score))
267
+
268
+ return results_data
269
+ except Exception as e:
270
+ logger.error(f"Search query failed: {e}")
271
+ return []
272
+
273
+ def generate_answer(user_query: str, context: str) -> str:
274
+ """Generate answer using Palm API"""
275
+ prompt = (
276
+ f"System: {Config.SYSTEM_PROMPT}\n\n"
277
+ f"Context:\n{context}\n\n"
278
+ f"User: {user_query}\nAssistant:"
279
+ )
280
+ try:
281
+ model = palm.GenerativeModel(Config.GENERATION_MODEL)
282
+ response = model.generate_content(prompt)
283
+ return response.text if hasattr(response, "text") else response
284
+ except Exception as e:
285
+ logger.error(f"Answer generation failed: {e}")
286
+ return "I'm sorry, I encountered an error generating a response."
287
+
288
+ # Firebase functions
289
+ def save_conversation_to_firestore(session_id, user_question, assistant_answer, feedback=None):
290
+ """Save conversation to Firestore"""
291
+ conv_ref = fs_client.collection("sessions").document(session_id).collection("conversations")
292
+ data = {
293
+ "user_question": user_question,
294
+ "assistant_answer": assistant_answer,
295
+ "feedback": feedback,
296
+ "timestamp": firestore.SERVER_TIMESTAMP
297
+ }
298
+ doc_ref = conv_ref.add(data)
299
+ return doc_ref[1].id
300
+
301
+ def update_feedback_in_firestore(session_id, conversation_id, feedback):
302
+ """Update feedback in Firestore"""
303
+ conv_doc = fs_client.collection("sessions").document(session_id).collection("conversations").document(conversation_id)
304
+ conv_doc.update({"feedback": feedback})
305
+
306
+ def handle_feedback(feedback_val):
307
+ """Handle user feedback"""
308
+ update_feedback_in_firestore(
309
+ st.session_state.session_id,
310
+ st.session_state.latest_conversation_id,
311
+ feedback_val
312
+ )
313
+ st.session_state.conversations[-1]["feedback"] = feedback_val
314
+
315
+ # Chat interface
316
+ def chat_app():
317
+ """Main chat interface"""
318
+ if "conversations" not in st.session_state:
319
+ st.session_state.conversations = []
320
+ if "session_id" not in st.session_state:
321
+ st.session_state.session_id = str(uuid.uuid4())
322
+
323
+ # Display conversation history
324
+ for conv in st.session_state.conversations:
325
+ with st.chat_message("user"):
326
+ st.write(conv["user_question"])
327
+ with st.chat_message("assistant"):
328
+ st.write(conv["assistant_answer"])
329
+ if conv.get("feedback"):
330
+ st.markdown(f"**Feedback:** {conv['feedback']}")
331
+
332
+ # Handle new user input
333
+ user_input = st.chat_input("Type your message here")
334
+ if user_input:
335
+ with st.chat_message("user"):
336
+ st.write(user_input)
337
+
338
+ results = search_query(user_input)
339
+ context = "\n\n".join([chunk for chunk, score in results]) if results else ""
340
+ answer = generate_answer(user_input, context)
341
+
342
+ with st.chat_message("assistant"):
343
+ st.write(answer)
344
+
345
+ conversation_id = save_conversation_to_firestore(
346
+ st.session_state.session_id,
347
+ user_question=user_input,
348
+ assistant_answer=answer
349
+ )
350
+ st.session_state.latest_conversation_id = conversation_id
351
+ st.session_state.conversations.append({
352
+ "user_question": user_input,
353
+ "assistant_answer": answer,
354
+ })
355
+
356
+ # Add feedback buttons
357
+ if "feedback" not in st.session_state.conversations[-1]:
358
+ col1, col2, col3, col4, col5, col6, col7, col8, col9, col10 = st.columns(10)
359
+ col1.button("πŸ‘", key=f"feedback_like_{len(st.session_state.conversations)}",
360
+ on_click=handle_feedback, args=("positive",))
361
+ col2.button("πŸ‘Ž", key=f"feedback_dislike_{len(st.session_state.conversations)}",
362
+ on_click=handle_feedback, args=("negative",))
363
+
364
+ def main():
365
+ """Main application"""
366
+ st.title("Chat with your files")
367
+
368
+ # Sidebar for file upload
369
+ st.sidebar.header("Upload Document")
370
+ uploaded_file = st.sidebar.file_uploader("Upload (.txt, .pdf, .docx)", type=["txt", "pdf", "docx"])
371
+
372
+ if uploaded_file and not st.session_state.get("doc_processed", False):
373
+ process_document(uploaded_file)
374
+
375
+ if "document_text" in st.session_state:
376
+ chat_app()
377
+ else:
378
+ st.info("Please upload and process a document from the sidebar to start chatting.")
379
+
380
+ # Footer
381
+ st.markdown(
382
+ """
383
+ <div style="position: fixed; right: 10px; bottom: 10px; font-size: 12px; z-index: 9999; text-align: right;">
384
+ Made by Danny.<br>
385
+ Your questions, our response as well as your feedback will be saved for evaluation purposes.
386
+ </div>
387
+ """,
388
+ unsafe_allow_html=True
389
+ )
390
+
391
+ if __name__ == "__main__":
392
+ main()