jessica45 commited on
Commit
8953dfc
·
verified ·
1 Parent(s): caab355

initial commit

Browse files
Files changed (6) hide show
  1. app.py +397 -0
  2. chroma_db_utils.py +249 -0
  3. gemini_embedding.py +19 -0
  4. pdf_utils.py +141 -0
  5. query_handler.py +72 -0
  6. requirement.txt +13 -0
app.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import streamlit as st
2
+ # from pdf_utils import extract_text_from_file, split_text
3
+ # from chroma_db_utils import create_chroma_db, load_chroma_collection
4
+ # from query_handler import handle_query
5
+ # import os
6
+ # import re
7
+ # import tempfile
8
+
9
+ # def generate_collection_name(file_path=None):
10
+ # """Generate a valid collection name from a file path."""
11
+ # base_name = os.path.basename(file_path) if file_path else "collection"
12
+ # # Remove file extension
13
+ # base_name = re.sub(r'\..*$', '', base_name)
14
+ # # Replace invalid characters and ensure it starts with a letter
15
+ # base_name = re.sub(r'\W+', '_', base_name)
16
+ # base_name = re.sub(r'^[^a-zA-Z]+', '', base_name)
17
+ # return base_name
18
+
19
+ # def process_uploaded_file(uploaded_file, chroma_db_path):
20
+ # """Process the uploaded file and create/load ChromaDB collection."""
21
+ # # Create a temporary file to store the uploaded content
22
+ # with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
23
+ # tmp_file.write(uploaded_file.getvalue())
24
+ # file_path = tmp_file.name
25
+
26
+ # try:
27
+ # # Generate collection name from original filename
28
+ # collection_name = generate_collection_name(uploaded_file.name)
29
+
30
+ # # Extract and process text
31
+ # file_text = extract_text_from_file(file_path)
32
+ # if file_text is None:
33
+ # return None, "Failed to extract text from the file."
34
+
35
+ # chunked_text = split_text(file_text)
36
+
37
+ # # Try to load existing collection or create new one
38
+ # try:
39
+ # db = load_chroma_collection(collection_name, chroma_db_path)
40
+ # st.success("Loaded existing ChromaDB collection.")
41
+ # except Exception:
42
+ # db = create_chroma_db(chunked_text, collection_name, chroma_db_path)
43
+ # st.success("Created new ChromaDB collection.")
44
+
45
+ # return db, None
46
+
47
+ # except Exception as e:
48
+ # return None, f"Error processing file: {str(e)}"
49
+ # finally:
50
+ # # Clean up temporary file
51
+ # os.unlink(file_path)
52
+
53
+ # def main():
54
+ # st.title("File Question Answering System")
55
+
56
+ # # Sidebar for configuration
57
+ # st.sidebar.header("Configuration")
58
+ # chroma_db_path = st.sidebar.text_input(
59
+ # "ChromaDB Path",
60
+ # value="./chroma_db",
61
+ # help="Directory where ChromaDB collections will be stored"
62
+ # )
63
+
64
+ # # Main content
65
+ # st.write("Upload a file and ask questions about its content!")
66
+
67
+ # # File uploader
68
+ # uploaded_file = st.file_uploader("Upload a file", type=["pdf", "docx", "txt"])
69
+
70
+ # # Session state initialization
71
+ # if 'db' not in st.session_state:
72
+ # st.session_state.db = None
73
+
74
+ # if uploaded_file is not None:
75
+ # # Process file if not already processed
76
+ # if st.session_state.db is None:
77
+ # with st.spinner("Processing PDF file..."):
78
+ # db, error = process_uploaded_file(uploaded_file, chroma_db_path)
79
+ # if error:
80
+ # st.error(error)
81
+ # else:
82
+ # st.session_state.db = db
83
+ # st.success("File processed successfully!")
84
+
85
+ # # Question answering interface
86
+ # st.subheader("Ask a Question")
87
+ # question = st.text_input("Enter your question:")
88
+
89
+ # if question:
90
+ # if st.session_state.db is not None:
91
+ # with st.spinner("Finding answer..."):
92
+ # answer = handle_query(question, st.session_state.db)
93
+ # st.subheader("Answer:")
94
+ # st.write(answer)
95
+ # else:
96
+ # st.error("Please wait for the file to be processed or try uploading again.")
97
+
98
+ # # Clear database button
99
+ # if st.button("Clear Database"):
100
+ # st.session_state.db = None
101
+ # st.success("Database cleared. You can upload a new file.")
102
+
103
+ # if __name__ == "__main__":
104
+ # main()
105
+ import streamlit as st
106
+ import os
107
+ from typing import List
108
+ import time
109
+ from pdf_utils import extract_text_from_file, split_text
110
+ from chroma_db_utils import create_chroma_db
111
+ from query_handler import handle_query
112
+
113
def initialize_session_state():
    """Ensure the chat history, vector-DB handle, and chunk cache exist in session state."""
    defaults = {"messages": [], "db": None, "chunks": []}
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
121
+
122
def process_uploaded_file(uploaded_file) -> List[str]:
    """Persist the uploaded file to a real temporary file, extract its text,
    and return the text split into chunks (empty list on failure).

    Fix: the original wrote the upload into the current working directory
    under the client-supplied filename — despite the comment claiming a
    temporary file — risking collisions with project files. A tempfile with
    the original suffix (needed for extension-based dispatch inside
    extract_text_from_file) is used instead.
    """
    import tempfile

    # Keep the extension so extract_text_from_file dispatches correctly.
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
        tmp_path = tmp.name

    try:
        extracted_text = extract_text_from_file(tmp_path)
        if extracted_text:
            return split_text(extracted_text)
        st.error("No text could be extracted from the file.")
        return []
    finally:
        # Always remove the temporary copy, even when extraction fails.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
142
+
143
def main():
    """Streamlit entry point: sidebar document upload plus a chat Q&A interface."""
    st.title("📚 Document Q&A System")

    initialize_session_state()

    # --- Sidebar: document upload -------------------------------------
    with st.sidebar:
        st.header("Document Upload")
        uploaded_file = st.file_uploader(
            "Upload your document",
            type=['pdf', 'docx', 'txt'],
            help="Supported formats: PDF, DOCX, TXT"
        )

        if uploaded_file:
            with st.spinner("Processing document..."):
                chunks = process_uploaded_file(uploaded_file)

                if chunks:
                    # Rebuild the vector DB from the freshly chunked document.
                    st.session_state.chunks = chunks
                    st.session_state.db = create_chroma_db(chunks)
                    st.success(f"Document processed! Created {len(chunks)} chunks.")

                    # Seed the chat with a one-time system greeting.
                    if not st.session_state.messages:
                        st.session_state.messages.append({
                            "role": "system",
                            "content": "I've processed your document. You can now ask questions about it!"
                        })

    # --- Chat interface ------------------------------------------------
    st.header("💬 Chat")

    # Replay the stored conversation.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    if prompt := st.chat_input("Ask a question about your document"):
        # Guard: a question is only answerable once a document is indexed.
        if st.session_state.db is None:
            st.error("Please upload a document first!")
            return

        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = handle_query(prompt, st.session_state.db)
                    st.write(response)
                    st.session_state.messages.append({
                        "role": "assistant",
                        "content": response
                    })
                except Exception as e:
                    st.error(f"Error generating response: {str(e)}")

    # Clear-chat control.
    # Fix: st.experimental_rerun() was deprecated and removed in newer
    # Streamlit releases; prefer st.rerun() when available and fall back
    # to the legacy name on older versions.
    if st.sidebar.button("Clear Chat"):
        st.session_state.messages = []
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

if __name__ == "__main__":
    main()
220
+
221
+
222
+
223
+
224
+ # import streamlit as st
225
+ # from chromadb.config import Settings
226
+ # import os
227
+ # import chromadb
228
+ # from typing import List
229
+ # import time
230
+ # import google
231
+ # import datetime
232
+ # # from chroma_db_utils import create_chroma_db, get_relevant_passage
233
+ # from query_handler import generate_answer, handle_query
234
+ # from pdf_utils import extract_text_from_file, split_text
235
+ # import logging
236
+
237
+ # # Configure logging
238
+ # logging.basicConfig(level=logging.INFO)
239
+ # logger = logging.getLogger(__name__)
240
+
241
+ # def create_chroma_db(chunks: List[str]):
242
+ # """Create and return an ephemeral ChromaDB collection."""
243
+ # try:
244
+ # # Initialize ChromaDB with ephemeral storage
245
+ # client = chromadb.EphemeralClient()
246
+
247
+ # # Create collection
248
+ # collection_name = f"temp_collection_{int(time.time())}"
249
+ # collection = client.create_collection(name=collection_name)
250
+
251
+ # # Add documents
252
+ # collection.add(
253
+ # documents=chunks,
254
+ # ids=[f"doc_{i}" for i in range(len(chunks))]
255
+ # )
256
+
257
+ # # Verify the data was added
258
+ # verify_count = collection.count()
259
+ # print(f"Verified: Added {verify_count} documents to collection {collection_name}")
260
+
261
+ # # Store both client and collection in session state
262
+ # st.session_state.chroma_client = client
263
+ # return collection
264
+
265
+ # except Exception as e:
266
+ # print(f"Error creating ChromaDB: {str(e)}")
267
+ # return None
268
+
269
+ # def get_relevant_passage(query: str, collection):
270
+ # """Get relevant passages from the collection."""
271
+ # try:
272
+ # # Use the collection directly since it's ephemeral
273
+ # results = collection.query(
274
+ # query_texts=[query],
275
+ # n_results=2
276
+ # )
277
+
278
+ # if results and 'documents' in results:
279
+ # print(f"Found {len(results['documents'])} relevant passages")
280
+ # return results['documents']
281
+ # return None
282
+
283
+ # except Exception as e:
284
+ # print(f"Error in get_relevant_passage: {str(e)}")
285
+ # return None
286
+
287
+ # def initialize_session_state():
288
+ # """Initialize Streamlit session state variables."""
289
+ # if "chat_history" not in st.session_state:
290
+ # st.session_state.chat_history = []
291
+ # if "chroma_collection" not in st.session_state:
292
+ # st.session_state.chroma_collection = None
293
+ # if "chroma_client" not in st.session_state:
294
+ # st.session_state.chroma_client = None
295
+
296
+ # def process_uploaded_file(uploaded_file) -> List[str]:
297
+ # """Process the uploaded file and return text chunks."""
298
+ # temp_file_path = f"/tmp/{uploaded_file.name}"
299
+
300
+ # try:
301
+ # with open(temp_file_path, "wb") as f:
302
+ # f.write(uploaded_file.getbuffer())
303
+
304
+ # # Extract text from the file
305
+ # extracted_text = extract_text_from_file(temp_file_path)
306
+
307
+ # if extracted_text:
308
+ # # Split text into chunks
309
+ # chunks = split_text(extracted_text)
310
+ # return chunks
311
+ # else:
312
+ # st.error("No text could be extracted from the file.")
313
+ # return []
314
+ # finally:
315
+ # if os.path.exists(temp_file_path):
316
+ # os.remove(temp_file_path)
317
+
318
+ # def chat_interface():
319
+ # st.title("Chat with Your Documents 📄💬")
320
+
321
+ # # Debug: Print current state
322
+ # print(f"Current chroma_collection state: {st.session_state.chroma_collection}")
323
+
324
+ # uploaded_files = st.file_uploader(
325
+ # "Upload your files (TXT, PDF)",
326
+ # accept_multiple_files=True,
327
+ # type=['txt', 'pdf']
328
+ # )
329
+
330
+ # if uploaded_files and st.button("Process Files"):
331
+ # with st.spinner("Processing files..."):
332
+ # all_chunks = []
333
+ # for uploaded_file in uploaded_files:
334
+ # chunks = process_uploaded_file(uploaded_file)
335
+ # print(f"Processed {len(chunks)} chunks from {uploaded_file.name}")
336
+ # if chunks:
337
+ # all_chunks.extend(chunks)
338
+
339
+ # if all_chunks:
340
+ # print(f"Creating ChromaDB with {len(all_chunks)} total chunks")
341
+ # # Create ChromaDB collection with all documents
342
+ # db = create_chroma_db(all_chunks)
343
+ # if db:
344
+ # # Verify the collection immediately after creation
345
+ # try:
346
+ # verify_count = db.count()
347
+ # print(f"Verification - Collection size: {verify_count}")
348
+ # # Try a test query
349
+ # test_query = db.query(
350
+ # query_texts=["test verification query"],
351
+ # n_results=1
352
+ # )
353
+ # print("Verification - Query test successful")
354
+
355
+ # st.session_state.chroma_collection = db
356
+ # st.success(f"Files processed successfully! {verify_count} chunks loaded.")
357
+ # except Exception as e:
358
+ # print(f"Verification failed: {str(e)}")
359
+ # st.error("Database verification failed")
360
+ # else:
361
+ # st.error("Failed to create database")
362
+
363
+ # # Query interface
364
+ # if st.session_state.chroma_collection is not None:
365
+ # print("ChromaDB collection found in session state")
366
+ # query = st.text_input("Ask a question about your documents:")
367
+ # if st.button("Send") and query:
368
+ # print(f"Processing query: {query}")
369
+ # with st.spinner("Generating response..."):
370
+ # try:
371
+ # # Verify both client and collection exist
372
+ # if st.session_state.chroma_client is None or st.session_state.chroma_collection is None:
373
+ # st.error("Please upload documents first")
374
+ # return
375
+
376
+ # collection = st.session_state.chroma_collection
377
+ # print(f"Collection name: {collection.name}")
378
+ # print(f"Collection size: {collection.count()}")
379
+
380
+ # relevant_passages = get_relevant_passage(query, collection)
381
+
382
+ # if relevant_passages:
383
+ # response = handle_query(query, relevant_passages)
384
+ # st.session_state.chat_history.append((query, response))
385
+ # else:
386
+ # st.warning("No relevant information found in the documents.")
387
+
388
+ # except Exception as e:
389
+ # print(f"Full error during query processing: {str(e)}")
390
+ # logger.exception("Detailed error trace:") # This will log the full stack trace
391
+ # st.error("Failed to process your question. Please try again.")
392
+ # else:
393
+ # print("No ChromaDB collection in session state")
394
+
395
+ # if __name__ == "__main__":
396
+ # initialize_session_state()
397
+ # chat_interface()
chroma_db_utils.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import chromadb
3
+ # import numpy as np
4
+ # from typing import List, Tuple
5
+ # from gemini_embedding import GeminiEmbeddingFunction
6
+
7
+ # def create_chroma_db(documents: List[str], dataset_name: str, base_path: str = "chroma_db"):
8
+ # """
9
+ # Creates a Chroma database using the provided documents.
10
+ # Automatically generates path and collection name based on dataset_name.
11
+ # """
12
+ # path = os.path.join(base_path, dataset_name)
13
+ # name = f"{dataset_name}_collection"
14
+
15
+ # if not os.path.exists(path):
16
+ # os.makedirs(path)
17
+
18
+ # chroma_client = chromadb.PersistentClient(path=path)
19
+ # db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())
20
+
21
+ # for i, doc in enumerate(documents):
22
+ # db.add(documents=[doc], ids=[str(i)])
23
+
24
+ # return db
25
+
26
+ # def load_chroma_collection(dataset_name: str, base_path: str = "chroma_db"):
27
+ # """
28
+ # Loads an existing Chroma collection.
29
+ # """
30
+ # path = os.path.join(base_path, dataset_name)
31
+ # name = f"{dataset_name}_collection"
32
+
33
+ # chroma_client = chromadb.PersistentClient(path=path)
34
+ # return chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())
35
+
36
+ # def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
37
+ # """
38
+ # Calculate cosine similarity between two vectors.
39
+ # Returns a value between -1 and 1, where 1 means most similar.
40
+ # """
41
+ # dot_product = np.dot(vec1, vec2)
42
+ # norm1 = np.linalg.norm(vec1)
43
+ # norm2 = np.linalg.norm(vec2)
44
+ # return dot_product / (norm1 * norm2)
45
+
46
+ # def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
47
+ # """
48
+ # Retrieves relevant passages using explicit cosine similarity calculation.
49
+ # """
50
+ # # Get query embedding
51
+ # query_embedding = db._embedding_function([query])[0]
52
+
53
+ # # Get all document embeddings
54
+ # all_docs = db.get(include=['embeddings', 'documents'])
55
+ # doc_embeddings = all_docs['embeddings']
56
+ # documents = all_docs['documents']
57
+
58
+ # # Calculate cosine similarity for each document
59
+ # similarities = []
60
+ # for doc_embedding in doc_embeddings:
61
+ # similarity = cosine_similarity(query_embedding, doc_embedding)
62
+ # similarities.append(similarity)
63
+
64
+ # # Sort documents by similarity
65
+ # doc_similarities = list(zip(documents, similarities))
66
+ # doc_similarities.sort(key=lambda x: x[1], reverse=True)
67
+
68
+ # # Take top n results
69
+ # top_results = doc_similarities[:n_results]
70
+
71
+ # # Print results for debugging
72
+ # print(f"Number of relevant passages retrieved: {len(top_results)}")
73
+ # for i, (doc, similarity) in enumerate(top_results):
74
+ # print(f"Passage {i+1} (Cosine Similarity: {similarity:.4f}): {doc[:100]}...")
75
+
76
+ # # Return just the documents
77
+ # return [doc for doc, _ in top_results]
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+ # in memory
86
+
87
+
88
+ # import chromadb
89
+ # from typing import List
90
+ # from gemini_embedding import GeminiEmbeddingFunction # Ensure this is correctly implemented
91
+ # import time
92
+ # from chromadb.config import Settings
93
+
94
+ # def create_chroma_db(chunks: List[str]):
95
+ # """Create and return an in-memory ChromaDB collection."""
96
+ # try:
97
+ # # Initialize in-memory ChromaDB with current recommended configuration
98
+ # client = chromadb.Client()
99
+
100
+ # # Create collection with unique name to avoid conflicts
101
+ # collection_name = f"temp_collection_{int(time.time())}"
102
+ # collection = client.create_collection(name=collection_name)
103
+
104
+ # # Add documents with unique IDs
105
+ # collection.add(
106
+ # documents=chunks,
107
+ # ids=[f"doc_{i}" for i in range(len(chunks))]
108
+ # )
109
+
110
+ # # Verify the data was added
111
+ # verify_count = collection.count()
112
+ # print(f"Verified: Added {verify_count} documents to collection {collection_name}")
113
+
114
+ # # Test query to ensure collection is working
115
+ # test_results = collection.query(
116
+ # query_texts=["test"],
117
+ # n_results=1
118
+ # )
119
+ # print("Verified: Collection is queryable")
120
+
121
+ # return collection
122
+
123
+ # except Exception as e:
124
+ # print(f"Error creating ChromaDB: {str(e)}")
125
+ # return None
126
+
127
+ # def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
128
+ # """
129
+ # Retrieves relevant passages using ChromaDB's similarity search.
130
+ # """
131
+ # try:
132
+ # if db is None:
133
+ # print("Database not initialized")
134
+ # return []
135
+
136
+ # # Verify collection has documents
137
+ # count = db.count()
138
+ # if count == 0:
139
+ # print("Collection is empty")
140
+ # return []
141
+
142
+ # # Query the database
143
+ # results = db.query(
144
+ # query_texts=[query],
145
+ # n_results=min(n_results, count) # Ensure we don't request more than we have
146
+ # )
147
+
148
+ # # Ensure results exist
149
+ # if not results["documents"]:
150
+ # print("No relevant passages found.")
151
+ # return []
152
+
153
+ # documents = results["documents"][0] # First result batch
154
+ # distances = results["distances"][0] # Corresponding distances
155
+
156
+ # # Debug output
157
+ # print(f"Number of relevant passages retrieved: {len(documents)}")
158
+ # for i, (doc, distance) in enumerate(zip(documents, distances)):
159
+ # similarity = 1 - distance # Convert distance to similarity
160
+ # print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")
161
+
162
+ # return documents
163
+ # except Exception as e:
164
+ # print(f"Error in get_relevant_passage: {str(e)}")
165
+ # return []
166
+
167
+
168
import chromadb
from chromadb.config import Settings
from typing import List
import os
from gemini_embedding import GeminiEmbeddingFunction
import datetime

# Single shared embedding-function instance, reused by every collection
# created in this module so it is only constructed once per process.
embedding_function = GeminiEmbeddingFunction()
175
+
176
def create_chroma_db(documents: List[str]):
    """
    Create (or reset) the persistent Chroma collection holding *documents*.

    The collection lives under ./chroma_db and is always named
    "document_collection"; any documents left over from a previous run are
    removed before the new ones are inserted.

    Args:
        documents: text chunks to embed and store.

    Returns:
        The populated chromadb collection.
    """
    # Create a persistent directory for ChromaDB.
    persist_directory = "chroma_db"
    os.makedirs(persist_directory, exist_ok=True)

    # Initialize the client with persistence.
    chroma_client = chromadb.PersistentClient(
        path=persist_directory,
    )

    # Fix: get_or_create_collection replaces the try/bare-`except:` probe,
    # which silently swallowed every error (including KeyboardInterrupt).
    db = chroma_client.get_or_create_collection(
        name="document_collection",
        embedding_function=embedding_function,
    )

    # Clear any previously stored documents.
    # Fix: chromadb rejects an empty ids list, so only delete when the
    # collection actually holds something.
    existing_ids = db.get()["ids"]
    if existing_ids:
        db.delete(ids=existing_ids)

    # Add documents in batches to avoid memory issues with large uploads.
    batch_size = 20
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i + batch_size]
        db.add(
            documents=batch,
            ids=[f"doc_{j}" for j in range(i, i + len(batch))],
        )

    return db
215
+
216
def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
    """
    Return up to *n_results* stored passages most similar to *query*.

    Returns an empty list when the collection is empty, the query yields
    nothing, or any chromadb error occurs (logged, not raised).
    """
    start_time = datetime.datetime.now()
    print(f"{start_time}: Starting ChromaDB query for question: {query[:50]}...")  # Log query start

    try:
        # Fix: an empty collection would make n_results = min(n, 0) = 0,
        # which chromadb rejects — bail out early instead.
        count = db.count()
        if count == 0:
            print("Collection is empty.")
            return []

        results = db.query(
            query_texts=[query],
            n_results=min(n_results, count),
            include=['documents', 'distances']
        )
        end_time = datetime.datetime.now()
        print(f"{end_time}: ChromaDB query finished. Time taken: {end_time - start_time}")  # Log the time taken

        # Ensure results exist and contain at least one document.
        if not results or 'documents' not in results or not results['documents'] or not results['documents'][0]:
            print("No relevant passages found.")
            return []

        documents = results['documents'][0]  # List of retrieved documents
        distances = results['distances'][0]  # Corresponding distances

        # Debugging output.
        print(f"Number of relevant passages retrieved: {len(documents)}")
        for i, (doc, distance) in enumerate(zip(documents, distances)):
            similarity = 1 - distance  # Convert distance to similarity score
            print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")

        return documents
    except Exception as e:
        print(f"Error in get_relevant_passage: {str(e)}")
        return []
249
+
gemini_embedding.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import google.generativeai as genai
from chromadb.api.types import Documents, Embeddings
from chromadb import EmbeddingFunction
from dotenv import load_dotenv

load_dotenv()
# Fix: os.environ["GEMINI_API_KEY"] raised KeyError at *import* time when the
# variable was unset, which made the friendlier runtime ValueError inside
# GeminiEmbeddingFunction unreachable. os.getenv returns None instead, so the
# check in __call__ can fire with a clear message.
gemini_api_key = os.getenv("GEMINI_API_KEY")
9
+
10
class GeminiEmbeddingFunction(EmbeddingFunction):
    """
    Chroma embedding function backed by the Gemini text-embedding API.
    """
    def __call__(self, input: Documents) -> Embeddings:
        # The key is resolved at call time so a missing key fails with a
        # clear error rather than at import.
        if not gemini_api_key:
            raise ValueError("Gemini API Key not provided. Please set GEMINI_API_KEY as an environment variable.")
        genai.configure(api_key=gemini_api_key)
        response = genai.embed_content(
            model="models/text-embedding-004",
            content=input,
            task_type="retrieval_document",
        )
        return response["embedding"]
pdf_utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import pdfplumber
from typing import List, Optional
import textract
from docx import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import logging
import warnings

# Configure module-wide logging so extraction problems surface in the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
14
+
15
def clean_text(text: str) -> str:
    """Normalize extracted text: collapse runs of horizontal whitespace,
    drop non-printable characters, and squeeze blank-line runs down to one
    blank line.

    Fix: the original collapsed ALL whitespace (including newlines) first,
    which made the later newline-preserving character filter and the
    blank-line-collapsing pass dead code. Horizontal whitespace is now
    collapsed without touching newlines, so paragraph structure survives.
    """
    text = re.sub(r'[^\S\n]+', ' ', text)  # collapse spaces/tabs, keep newlines
    text = ''.join(char for char in text if char.isprintable() or char == '\n')  # drop non-printables
    text = re.sub(r'\n\s*\n', '\n\n', text)  # squeeze blank-line runs
    return text.strip()
21
+
22
def extract_text_from_pdf(pdf_path: str) -> Optional[str]:
    """
    Pull text out of every page of a PDF via pdfplumber.

    Pages that yield no text are skipped with a warning; returns None when
    nothing at all could be extracted or the file cannot be opened.
    """
    pages = []
    try:
        with pdfplumber.open(pdf_path) as pdf:
            for page_num, page in enumerate(pdf.pages, 1):
                try:
                    text = page.extract_text()
                except Exception as e:
                    logger.error(f"Error extracting text from page {page_num}: {e}")
                    continue
                if text:
                    pages.append(text)
                else:
                    logger.warning(f"No text extracted from page {page_num}")
    except Exception as e:
        logger.error(f"Failed to process PDF {pdf_path}: {e}")
        return None

    if not pages:
        logger.warning("No text was extracted from any page of the PDF")
        return None

    return clean_text('\n'.join(pages))
48
+
49
def extract_text_from_docx(docx_path: str) -> Optional[str]:
    """
    Read every non-empty paragraph of a DOCX file and return the cleaned
    text, or None when the file is unreadable or contains no text.
    """
    try:
        paragraphs = [p.text for p in Document(docx_path).paragraphs if p.text.strip()]
        joined = '\n'.join(paragraphs)
        return clean_text(joined) if joined else None
    except Exception as e:
        logger.error(f"Failed to process DOCX {docx_path}: {e}")
        return None
60
+
61
+ import tempfile
62
+
63
def extract_text_from_file(uploaded_file) -> Optional[str]:
    """
    Extract text from a path (str) or a file-like object.

    Dispatches on extension: PDF -> pdfplumber, DOCX -> python-docx,
    TXT -> direct read (UTF-8 with latin-1 fallback), anything else ->
    textract.

    Fixes two defects in the file-like branch: the temporary copy lost the
    original filename's extension (so PDFs/DOCX fell through to the generic
    textract branch), and the temporary file was never deleted.

    Returns the extracted text, or None on failure.
    """
    temp_path = None
    if isinstance(uploaded_file, str):  # a plain path was given
        file_path = uploaded_file
    else:  # file-like object (e.g. an upload): spill to a temp file
        # Preserve the extension (when the object exposes a name) so the
        # extension dispatch below still recognizes the file type.
        suffix = os.path.splitext(getattr(uploaded_file, 'name', ''))[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(uploaded_file.read())
            file_path = temp_path = temp_file.name

    try:
        if not os.path.exists(file_path):
            logger.error(f"File not found: {file_path}")
            return None

        _, file_extension = os.path.splitext(file_path)
        file_extension = file_extension.lower()

        try:
            if file_extension == ".pdf":
                text = extract_text_from_pdf(file_path)
            elif file_extension == ".docx":
                text = extract_text_from_docx(file_path)
            elif file_extension == ".txt":
                try:
                    with open(file_path, "r", encoding="utf-8") as file:
                        text = clean_text(file.read())
                except UnicodeDecodeError:
                    # Fall back for legacy single-byte encodings.
                    with open(file_path, "r", encoding="latin-1") as file:
                        text = clean_text(file.read())
            else:
                # Unknown extension: let textract take its best shot.
                text = clean_text(textract.process(file_path).decode("utf-8"))

            if not text:
                logger.warning(f"No text content extracted from {file_path}")
                return None

            return text

        except Exception as e:
            logger.error(f"Error extracting text from {file_path}: {e}")
            return None
    finally:
        # Clean up the temporary copy made for file-like inputs.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
106
+
107
+
108
def split_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
    """
    Break *text* into overlapping chunks suitable for embedding.

    Returns an empty list for empty input or when the splitter fails
    (the error is logged rather than raised).
    """
    if not text:
        logger.warning("Empty text provided for splitting")
        return []

    try:
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False
        )
        chunks = splitter.split_text(text)
    except Exception as e:
        logger.error(f"Error splitting text: {e}")
        return []

    logger.info(f"Split text into {len(chunks)} chunks")
    return chunks
133
+
134
# Example usage / smoke test: pass a file path on the command line.
if __name__ == "__main__":
    import sys

    # Fix: the hard-coded personal path (a GitHub recovery-codes file) is
    # replaced by a command-line argument so the example is runnable for
    # anyone and leaks nothing.
    if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
        file_text = extract_text_from_file(sys.argv[1])
        if file_text:
            chunks = split_text(file_text)
            print(f"Successfully processed file into {len(chunks)} chunks")
    else:
        print("Usage: python pdf_utils.py <file>")
query_handler.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import google.generativeai as genai
from chroma_db_utils import get_relevant_passage
import time
import datetime
import google.api_core.exceptions

# Constants for Gemini API retry/rate-limit handling.
MAX_RETRIES = 3
RETRY_DELAY = 1  # Initial delay in seconds (doubled on each retry)
MODEL_NAME = "gemini-1.5-flash"
REQUESTS_PER_MINUTE = 3  # Free tier limit
REQUEST_INTERVAL = 60 / REQUESTS_PER_MINUTE  # Ensures we stay within the rate limit
+
15
def make_rag_prompt(query: str, relevant_passage: str) -> str:
    """
    Creates a prompt for the RAG model.
    """
    # Strip quotes and flatten newlines so the passage embeds safely into
    # the template below (lossy, but acceptable for plain reference text).
    escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
    prompt = f'''
    You are a helpful and informative bot that answers questions using the REFERENCE TEXT below.
    If the REFERENCE TEXT is irrelevant to the question, say "I cannot answer this question based on the provided information."

    QUESTION: {query}

    REFERENCE TEXT:
    {escaped}

    ANSWER:
    '''
    return prompt
32
+
33
def generate_answer(prompt: str) -> str:
    """
    Calls the Gemini API with retries and rate limiting.

    Retries up to MAX_RETRIES times with exponential backoff when the API
    reports resource exhaustion (HTTP 429).

    Raises:
        ValueError: when GEMINI_API_KEY is not set in the environment.
        Exception: when all retry attempts were rate-limited.
    """
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        raise ValueError("Gemini API Key not provided.")

    genai.configure(api_key=gemini_api_key)
    model = genai.GenerativeModel(MODEL_NAME)

    for attempt in range(MAX_RETRIES):
        start_time = datetime.datetime.now()
        print(f"{start_time}: Making Gemini API request (attempt {attempt + 1}/{MAX_RETRIES})...")
        try:
            response = model.generate_content(prompt)
            end_time = datetime.datetime.now()
            print(f"{end_time}: Gemini API request successful. Time taken: {end_time - start_time}")

            # Enforce rate limiting
            # NOTE(review): this sleeps REQUEST_INTERVAL (20s at 3 req/min)
            # on the success path BEFORE returning, delaying every answer —
            # tracking a last-request timestamp instead would avoid the wait.
            # Confirm intent before changing.
            time.sleep(REQUEST_INTERVAL)
            return response.text
        except google.api_core.exceptions.ResourceExhausted as e:
            if e.code == 429:  # Too Many Requests
                delay = RETRY_DELAY * (2 ** attempt)  # Exponential backoff
                print(f"Rate limit hit. Retrying in {delay} seconds (attempt {attempt + 1}/{MAX_RETRIES})...")
                time.sleep(delay)
            else:
                raise  # Re-raise other exceptions

    raise Exception("Max retries exceeded for Gemini API request.")
64
+
65
def handle_query(query: str, db, n_results: int = 5) -> str:
    """
    Answer *query* by retrieving the top-matching passages from *db* and
    feeding them to the Gemini model as reference text.
    """
    passages = get_relevant_passage(query, db, n_results)
    reference = " ".join(passages)
    return generate_answer(make_rag_prompt(query, relevant_passage=reference))
+ return generate_answer(prompt)
requirement.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
google-generativeai>=0.3.0
chromadb>=0.4.0
pdfplumber
python-docx>=0.8.11
textract>=1.6.5
langchain>=0.1.0
numpy>=1.21.0
python-dotenv>=0.19.0
streamlit>=1.18.0
# Removed duplicate `chromadb` entry (was listed twice with different pins).
# Removed `typing`, `warnings`, `logging`: these are Python standard-library
# modules, not pip packages — listing them breaks `pip install -r`.
+ logging>=0.5.0