initial commit
Browse files- app.py +397 -0
- chroma_db_utils.py +249 -0
- gemini_embedding.py +19 -0
- pdf_utils.py +141 -0
- query_handler.py +72 -0
- requirement.txt +13 -0
app.py
ADDED
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import streamlit as st
|
2 |
+
# from pdf_utils import extract_text_from_file, split_text
|
3 |
+
# from chroma_db_utils import create_chroma_db, load_chroma_collection
|
4 |
+
# from query_handler import handle_query
|
5 |
+
# import os
|
6 |
+
# import re
|
7 |
+
# import tempfile
|
8 |
+
|
9 |
+
# def generate_collection_name(file_path=None):
|
10 |
+
# """Generate a valid collection name from a file path."""
|
11 |
+
# base_name = os.path.basename(file_path) if file_path else "collection"
|
12 |
+
# # Remove file extension
|
13 |
+
# base_name = re.sub(r'\..*$', '', base_name)
|
14 |
+
# # Replace invalid characters and ensure it starts with a letter
|
15 |
+
# base_name = re.sub(r'\W+', '_', base_name)
|
16 |
+
# base_name = re.sub(r'^[^a-zA-Z]+', '', base_name)
|
17 |
+
# return base_name
|
18 |
+
|
19 |
+
# def process_uploaded_file(uploaded_file, chroma_db_path):
|
20 |
+
# """Process the uploaded file and create/load ChromaDB collection."""
|
21 |
+
# # Create a temporary file to store the uploaded content
|
22 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
|
23 |
+
# tmp_file.write(uploaded_file.getvalue())
|
24 |
+
# file_path = tmp_file.name
|
25 |
+
|
26 |
+
# try:
|
27 |
+
# # Generate collection name from original filename
|
28 |
+
# collection_name = generate_collection_name(uploaded_file.name)
|
29 |
+
|
30 |
+
# # Extract and process text
|
31 |
+
# file_text = extract_text_from_file(file_path)
|
32 |
+
# if file_text is None:
|
33 |
+
# return None, "Failed to extract text from the file."
|
34 |
+
|
35 |
+
# chunked_text = split_text(file_text)
|
36 |
+
|
37 |
+
# # Try to load existing collection or create new one
|
38 |
+
# try:
|
39 |
+
# db = load_chroma_collection(collection_name, chroma_db_path)
|
40 |
+
# st.success("Loaded existing ChromaDB collection.")
|
41 |
+
# except Exception:
|
42 |
+
# db = create_chroma_db(chunked_text, collection_name, chroma_db_path)
|
43 |
+
# st.success("Created new ChromaDB collection.")
|
44 |
+
|
45 |
+
# return db, None
|
46 |
+
|
47 |
+
# except Exception as e:
|
48 |
+
# return None, f"Error processing file: {str(e)}"
|
49 |
+
# finally:
|
50 |
+
# # Clean up temporary file
|
51 |
+
# os.unlink(file_path)
|
52 |
+
|
53 |
+
# def main():
|
54 |
+
# st.title("File Question Answering System")
|
55 |
+
|
56 |
+
# # Sidebar for configuration
|
57 |
+
# st.sidebar.header("Configuration")
|
58 |
+
# chroma_db_path = st.sidebar.text_input(
|
59 |
+
# "ChromaDB Path",
|
60 |
+
# value="./chroma_db",
|
61 |
+
# help="Directory where ChromaDB collections will be stored"
|
62 |
+
# )
|
63 |
+
|
64 |
+
# # Main content
|
65 |
+
# st.write("Upload a file and ask questions about its content!")
|
66 |
+
|
67 |
+
# # File uploader
|
68 |
+
# uploaded_file = st.file_uploader("Upload a file", type=["pdf", "docx", "txt"])
|
69 |
+
|
70 |
+
# # Session state initialization
|
71 |
+
# if 'db' not in st.session_state:
|
72 |
+
# st.session_state.db = None
|
73 |
+
|
74 |
+
# if uploaded_file is not None:
|
75 |
+
# # Process file if not already processed
|
76 |
+
# if st.session_state.db is None:
|
77 |
+
# with st.spinner("Processing PDF file..."):
|
78 |
+
# db, error = process_uploaded_file(uploaded_file, chroma_db_path)
|
79 |
+
# if error:
|
80 |
+
# st.error(error)
|
81 |
+
# else:
|
82 |
+
# st.session_state.db = db
|
83 |
+
# st.success("File processed successfully!")
|
84 |
+
|
85 |
+
# # Question answering interface
|
86 |
+
# st.subheader("Ask a Question")
|
87 |
+
# question = st.text_input("Enter your question:")
|
88 |
+
|
89 |
+
# if question:
|
90 |
+
# if st.session_state.db is not None:
|
91 |
+
# with st.spinner("Finding answer..."):
|
92 |
+
# answer = handle_query(question, st.session_state.db)
|
93 |
+
# st.subheader("Answer:")
|
94 |
+
# st.write(answer)
|
95 |
+
# else:
|
96 |
+
# st.error("Please wait for the file to be processed or try uploading again.")
|
97 |
+
|
98 |
+
# # Clear database button
|
99 |
+
# if st.button("Clear Database"):
|
100 |
+
# st.session_state.db = None
|
101 |
+
# st.success("Database cleared. You can upload a new file.")
|
102 |
+
|
103 |
+
# if __name__ == "__main__":
|
104 |
+
# main()
|
105 |
+
import streamlit as st
|
106 |
+
import os
|
107 |
+
from typing import List
|
108 |
+
import time
|
109 |
+
from pdf_utils import extract_text_from_file, split_text
|
110 |
+
from chroma_db_utils import create_chroma_db
|
111 |
+
from query_handler import handle_query
|
112 |
+
|
113 |
+
def initialize_session_state():
|
114 |
+
"""Initialize session state variables."""
|
115 |
+
if 'messages' not in st.session_state:
|
116 |
+
st.session_state.messages = []
|
117 |
+
if 'db' not in st.session_state:
|
118 |
+
st.session_state.db = None
|
119 |
+
if 'chunks' not in st.session_state:
|
120 |
+
st.session_state.chunks = []
|
121 |
+
|
122 |
+
def process_uploaded_file(uploaded_file) -> List[str]:
|
123 |
+
"""Process the uploaded file and return text chunks."""
|
124 |
+
# Create a temporary file to store the uploaded content
|
125 |
+
with open(uploaded_file.name, "wb") as f:
|
126 |
+
f.write(uploaded_file.getbuffer())
|
127 |
+
|
128 |
+
try:
|
129 |
+
# Extract text from the file
|
130 |
+
extracted_text = extract_text_from_file(uploaded_file.name)
|
131 |
+
if extracted_text:
|
132 |
+
# Split text into chunks
|
133 |
+
chunks = split_text(extracted_text)
|
134 |
+
return chunks
|
135 |
+
else:
|
136 |
+
st.error("No text could be extracted from the file.")
|
137 |
+
return []
|
138 |
+
finally:
|
139 |
+
# Clean up temporary file
|
140 |
+
if os.path.exists(uploaded_file.name):
|
141 |
+
os.remove(uploaded_file.name)
|
142 |
+
|
143 |
+
def main():
|
144 |
+
st.title("📚 Document Q&A System")
|
145 |
+
|
146 |
+
# Initialize session state
|
147 |
+
initialize_session_state()
|
148 |
+
|
149 |
+
# Sidebar for file upload
|
150 |
+
with st.sidebar:
|
151 |
+
st.header("Document Upload")
|
152 |
+
uploaded_file = st.file_uploader(
|
153 |
+
"Upload your document",
|
154 |
+
type=['pdf', 'docx', 'txt'],
|
155 |
+
help="Supported formats: PDF, DOCX, TXT"
|
156 |
+
)
|
157 |
+
|
158 |
+
if uploaded_file:
|
159 |
+
with st.spinner("Processing document..."):
|
160 |
+
# Process the uploaded file
|
161 |
+
chunks = process_uploaded_file(uploaded_file)
|
162 |
+
|
163 |
+
if chunks:
|
164 |
+
# Create/update the database
|
165 |
+
st.session_state.chunks = chunks
|
166 |
+
st.session_state.db = create_chroma_db(chunks)
|
167 |
+
st.success(f"Document processed! Created {len(chunks)} chunks.")
|
168 |
+
|
169 |
+
# Add system message to chat history
|
170 |
+
if not st.session_state.messages:
|
171 |
+
st.session_state.messages.append({
|
172 |
+
"role": "system",
|
173 |
+
"content": "I've processed your document. You can now ask questions about it!"
|
174 |
+
})
|
175 |
+
|
176 |
+
# Main chat interface
|
177 |
+
st.header("💬 Chat")
|
178 |
+
|
179 |
+
# Display chat messages
|
180 |
+
for message in st.session_state.messages:
|
181 |
+
with st.chat_message(message["role"]):
|
182 |
+
st.write(message["content"])
|
183 |
+
|
184 |
+
# Chat input
|
185 |
+
if prompt := st.chat_input("Ask a question about your document"):
|
186 |
+
# Only process if we have a database
|
187 |
+
if st.session_state.db is None:
|
188 |
+
st.error("Please upload a document first!")
|
189 |
+
return
|
190 |
+
|
191 |
+
# Add user message to chat history
|
192 |
+
st.session_state.messages.append({"role": "user", "content": prompt})
|
193 |
+
|
194 |
+
# Display user message
|
195 |
+
with st.chat_message("user"):
|
196 |
+
st.write(prompt)
|
197 |
+
|
198 |
+
# Generate and display assistant response
|
199 |
+
with st.chat_message("assistant"):
|
200 |
+
with st.spinner("Thinking..."):
|
201 |
+
try:
|
202 |
+
response = handle_query(prompt, st.session_state.db)
|
203 |
+
st.write(response)
|
204 |
+
|
205 |
+
# Add assistant response to chat history
|
206 |
+
st.session_state.messages.append({
|
207 |
+
"role": "assistant",
|
208 |
+
"content": response
|
209 |
+
})
|
210 |
+
except Exception as e:
|
211 |
+
st.error(f"Error generating response: {str(e)}")
|
212 |
+
|
213 |
+
# Add a clear chat button
|
214 |
+
if st.sidebar.button("Clear Chat"):
|
215 |
+
st.session_state.messages = []
|
216 |
+
st.experimental_rerun()
|
217 |
+
|
218 |
+
if __name__ == "__main__":
|
219 |
+
main()
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
# import streamlit as st
|
225 |
+
# from chromadb.config import Settings
|
226 |
+
# import os
|
227 |
+
# import chromadb
|
228 |
+
# from typing import List
|
229 |
+
# import time
|
230 |
+
# import google
|
231 |
+
# import datetime
|
232 |
+
# # from chroma_db_utils import create_chroma_db, get_relevant_passage
|
233 |
+
# from query_handler import generate_answer, handle_query
|
234 |
+
# from pdf_utils import extract_text_from_file, split_text
|
235 |
+
# import logging
|
236 |
+
|
237 |
+
# # Configure logging
|
238 |
+
# logging.basicConfig(level=logging.INFO)
|
239 |
+
# logger = logging.getLogger(__name__)
|
240 |
+
|
241 |
+
# def create_chroma_db(chunks: List[str]):
|
242 |
+
# """Create and return an ephemeral ChromaDB collection."""
|
243 |
+
# try:
|
244 |
+
# # Initialize ChromaDB with ephemeral storage
|
245 |
+
# client = chromadb.EphemeralClient()
|
246 |
+
|
247 |
+
# # Create collection
|
248 |
+
# collection_name = f"temp_collection_{int(time.time())}"
|
249 |
+
# collection = client.create_collection(name=collection_name)
|
250 |
+
|
251 |
+
# # Add documents
|
252 |
+
# collection.add(
|
253 |
+
# documents=chunks,
|
254 |
+
# ids=[f"doc_{i}" for i in range(len(chunks))]
|
255 |
+
# )
|
256 |
+
|
257 |
+
# # Verify the data was added
|
258 |
+
# verify_count = collection.count()
|
259 |
+
# print(f"Verified: Added {verify_count} documents to collection {collection_name}")
|
260 |
+
|
261 |
+
# # Store both client and collection in session state
|
262 |
+
# st.session_state.chroma_client = client
|
263 |
+
# return collection
|
264 |
+
|
265 |
+
# except Exception as e:
|
266 |
+
# print(f"Error creating ChromaDB: {str(e)}")
|
267 |
+
# return None
|
268 |
+
|
269 |
+
# def get_relevant_passage(query: str, collection):
|
270 |
+
# """Get relevant passages from the collection."""
|
271 |
+
# try:
|
272 |
+
# # Use the collection directly since it's ephemeral
|
273 |
+
# results = collection.query(
|
274 |
+
# query_texts=[query],
|
275 |
+
# n_results=2
|
276 |
+
# )
|
277 |
+
|
278 |
+
# if results and 'documents' in results:
|
279 |
+
# print(f"Found {len(results['documents'])} relevant passages")
|
280 |
+
# return results['documents']
|
281 |
+
# return None
|
282 |
+
|
283 |
+
# except Exception as e:
|
284 |
+
# print(f"Error in get_relevant_passage: {str(e)}")
|
285 |
+
# return None
|
286 |
+
|
287 |
+
# def initialize_session_state():
|
288 |
+
# """Initialize Streamlit session state variables."""
|
289 |
+
# if "chat_history" not in st.session_state:
|
290 |
+
# st.session_state.chat_history = []
|
291 |
+
# if "chroma_collection" not in st.session_state:
|
292 |
+
# st.session_state.chroma_collection = None
|
293 |
+
# if "chroma_client" not in st.session_state:
|
294 |
+
# st.session_state.chroma_client = None
|
295 |
+
|
296 |
+
# def process_uploaded_file(uploaded_file) -> List[str]:
|
297 |
+
# """Process the uploaded file and return text chunks."""
|
298 |
+
# temp_file_path = f"/tmp/{uploaded_file.name}"
|
299 |
+
|
300 |
+
# try:
|
301 |
+
# with open(temp_file_path, "wb") as f:
|
302 |
+
# f.write(uploaded_file.getbuffer())
|
303 |
+
|
304 |
+
# # Extract text from the file
|
305 |
+
# extracted_text = extract_text_from_file(temp_file_path)
|
306 |
+
|
307 |
+
# if extracted_text:
|
308 |
+
# # Split text into chunks
|
309 |
+
# chunks = split_text(extracted_text)
|
310 |
+
# return chunks
|
311 |
+
# else:
|
312 |
+
# st.error("No text could be extracted from the file.")
|
313 |
+
# return []
|
314 |
+
# finally:
|
315 |
+
# if os.path.exists(temp_file_path):
|
316 |
+
# os.remove(temp_file_path)
|
317 |
+
|
318 |
+
# def chat_interface():
|
319 |
+
# st.title("Chat with Your Documents 📄💬")
|
320 |
+
|
321 |
+
# # Debug: Print current state
|
322 |
+
# print(f"Current chroma_collection state: {st.session_state.chroma_collection}")
|
323 |
+
|
324 |
+
# uploaded_files = st.file_uploader(
|
325 |
+
# "Upload your files (TXT, PDF)",
|
326 |
+
# accept_multiple_files=True,
|
327 |
+
# type=['txt', 'pdf']
|
328 |
+
# )
|
329 |
+
|
330 |
+
# if uploaded_files and st.button("Process Files"):
|
331 |
+
# with st.spinner("Processing files..."):
|
332 |
+
# all_chunks = []
|
333 |
+
# for uploaded_file in uploaded_files:
|
334 |
+
# chunks = process_uploaded_file(uploaded_file)
|
335 |
+
# print(f"Processed {len(chunks)} chunks from {uploaded_file.name}")
|
336 |
+
# if chunks:
|
337 |
+
# all_chunks.extend(chunks)
|
338 |
+
|
339 |
+
# if all_chunks:
|
340 |
+
# print(f"Creating ChromaDB with {len(all_chunks)} total chunks")
|
341 |
+
# # Create ChromaDB collection with all documents
|
342 |
+
# db = create_chroma_db(all_chunks)
|
343 |
+
# if db:
|
344 |
+
# # Verify the collection immediately after creation
|
345 |
+
# try:
|
346 |
+
# verify_count = db.count()
|
347 |
+
# print(f"Verification - Collection size: {verify_count}")
|
348 |
+
# # Try a test query
|
349 |
+
# test_query = db.query(
|
350 |
+
# query_texts=["test verification query"],
|
351 |
+
# n_results=1
|
352 |
+
# )
|
353 |
+
# print("Verification - Query test successful")
|
354 |
+
|
355 |
+
# st.session_state.chroma_collection = db
|
356 |
+
# st.success(f"Files processed successfully! {verify_count} chunks loaded.")
|
357 |
+
# except Exception as e:
|
358 |
+
# print(f"Verification failed: {str(e)}")
|
359 |
+
# st.error("Database verification failed")
|
360 |
+
# else:
|
361 |
+
# st.error("Failed to create database")
|
362 |
+
|
363 |
+
# # Query interface
|
364 |
+
# if st.session_state.chroma_collection is not None:
|
365 |
+
# print("ChromaDB collection found in session state")
|
366 |
+
# query = st.text_input("Ask a question about your documents:")
|
367 |
+
# if st.button("Send") and query:
|
368 |
+
# print(f"Processing query: {query}")
|
369 |
+
# with st.spinner("Generating response..."):
|
370 |
+
# try:
|
371 |
+
# # Verify both client and collection exist
|
372 |
+
# if st.session_state.chroma_client is None or st.session_state.chroma_collection is None:
|
373 |
+
# st.error("Please upload documents first")
|
374 |
+
# return
|
375 |
+
|
376 |
+
# collection = st.session_state.chroma_collection
|
377 |
+
# print(f"Collection name: {collection.name}")
|
378 |
+
# print(f"Collection size: {collection.count()}")
|
379 |
+
|
380 |
+
# relevant_passages = get_relevant_passage(query, collection)
|
381 |
+
|
382 |
+
# if relevant_passages:
|
383 |
+
# response = handle_query(query, relevant_passages)
|
384 |
+
# st.session_state.chat_history.append((query, response))
|
385 |
+
# else:
|
386 |
+
# st.warning("No relevant information found in the documents.")
|
387 |
+
|
388 |
+
# except Exception as e:
|
389 |
+
# print(f"Full error during query processing: {str(e)}")
|
390 |
+
# logger.exception("Detailed error trace:") # This will log the full stack trace
|
391 |
+
# st.error("Failed to process your question. Please try again.")
|
392 |
+
# else:
|
393 |
+
# print("No ChromaDB collection in session state")
|
394 |
+
|
395 |
+
# if __name__ == "__main__":
|
396 |
+
# initialize_session_state()
|
397 |
+
# chat_interface()
|
chroma_db_utils.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import os
|
2 |
+
# import chromadb
|
3 |
+
# import numpy as np
|
4 |
+
# from typing import List, Tuple
|
5 |
+
# from gemini_embedding import GeminiEmbeddingFunction
|
6 |
+
|
7 |
+
# def create_chroma_db(documents: List[str], dataset_name: str, base_path: str = "chroma_db"):
|
8 |
+
# """
|
9 |
+
# Creates a Chroma database using the provided documents.
|
10 |
+
# Automatically generates path and collection name based on dataset_name.
|
11 |
+
# """
|
12 |
+
# path = os.path.join(base_path, dataset_name)
|
13 |
+
# name = f"{dataset_name}_collection"
|
14 |
+
|
15 |
+
# if not os.path.exists(path):
|
16 |
+
# os.makedirs(path)
|
17 |
+
|
18 |
+
# chroma_client = chromadb.PersistentClient(path=path)
|
19 |
+
# db = chroma_client.create_collection(name=name, embedding_function=GeminiEmbeddingFunction())
|
20 |
+
|
21 |
+
# for i, doc in enumerate(documents):
|
22 |
+
# db.add(documents=[doc], ids=[str(i)])
|
23 |
+
|
24 |
+
# return db
|
25 |
+
|
26 |
+
# def load_chroma_collection(dataset_name: str, base_path: str = "chroma_db"):
|
27 |
+
# """
|
28 |
+
# Loads an existing Chroma collection.
|
29 |
+
# """
|
30 |
+
# path = os.path.join(base_path, dataset_name)
|
31 |
+
# name = f"{dataset_name}_collection"
|
32 |
+
|
33 |
+
# chroma_client = chromadb.PersistentClient(path=path)
|
34 |
+
# return chroma_client.get_collection(name=name, embedding_function=GeminiEmbeddingFunction())
|
35 |
+
|
36 |
+
# def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
|
37 |
+
# """
|
38 |
+
# Calculate cosine similarity between two vectors.
|
39 |
+
# Returns a value between -1 and 1, where 1 means most similar.
|
40 |
+
# """
|
41 |
+
# dot_product = np.dot(vec1, vec2)
|
42 |
+
# norm1 = np.linalg.norm(vec1)
|
43 |
+
# norm2 = np.linalg.norm(vec2)
|
44 |
+
# return dot_product / (norm1 * norm2)
|
45 |
+
|
46 |
+
# def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
47 |
+
# """
|
48 |
+
# Retrieves relevant passages using explicit cosine similarity calculation.
|
49 |
+
# """
|
50 |
+
# # Get query embedding
|
51 |
+
# query_embedding = db._embedding_function([query])[0]
|
52 |
+
|
53 |
+
# # Get all document embeddings
|
54 |
+
# all_docs = db.get(include=['embeddings', 'documents'])
|
55 |
+
# doc_embeddings = all_docs['embeddings']
|
56 |
+
# documents = all_docs['documents']
|
57 |
+
|
58 |
+
# # Calculate cosine similarity for each document
|
59 |
+
# similarities = []
|
60 |
+
# for doc_embedding in doc_embeddings:
|
61 |
+
# similarity = cosine_similarity(query_embedding, doc_embedding)
|
62 |
+
# similarities.append(similarity)
|
63 |
+
|
64 |
+
# # Sort documents by similarity
|
65 |
+
# doc_similarities = list(zip(documents, similarities))
|
66 |
+
# doc_similarities.sort(key=lambda x: x[1], reverse=True)
|
67 |
+
|
68 |
+
# # Take top n results
|
69 |
+
# top_results = doc_similarities[:n_results]
|
70 |
+
|
71 |
+
# # Print results for debugging
|
72 |
+
# print(f"Number of relevant passages retrieved: {len(top_results)}")
|
73 |
+
# for i, (doc, similarity) in enumerate(top_results):
|
74 |
+
# print(f"Passage {i+1} (Cosine Similarity: {similarity:.4f}): {doc[:100]}...")
|
75 |
+
|
76 |
+
# # Return just the documents
|
77 |
+
# return [doc for doc, _ in top_results]
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
# in memory
|
86 |
+
|
87 |
+
|
88 |
+
# import chromadb
|
89 |
+
# from typing import List
|
90 |
+
# from gemini_embedding import GeminiEmbeddingFunction # Ensure this is correctly implemented
|
91 |
+
# import time
|
92 |
+
# from chromadb.config import Settings
|
93 |
+
|
94 |
+
# def create_chroma_db(chunks: List[str]):
|
95 |
+
# """Create and return an in-memory ChromaDB collection."""
|
96 |
+
# try:
|
97 |
+
# # Initialize in-memory ChromaDB with current recommended configuration
|
98 |
+
# client = chromadb.Client()
|
99 |
+
|
100 |
+
# # Create collection with unique name to avoid conflicts
|
101 |
+
# collection_name = f"temp_collection_{int(time.time())}"
|
102 |
+
# collection = client.create_collection(name=collection_name)
|
103 |
+
|
104 |
+
# # Add documents with unique IDs
|
105 |
+
# collection.add(
|
106 |
+
# documents=chunks,
|
107 |
+
# ids=[f"doc_{i}" for i in range(len(chunks))]
|
108 |
+
# )
|
109 |
+
|
110 |
+
# # Verify the data was added
|
111 |
+
# verify_count = collection.count()
|
112 |
+
# print(f"Verified: Added {verify_count} documents to collection {collection_name}")
|
113 |
+
|
114 |
+
# # Test query to ensure collection is working
|
115 |
+
# test_results = collection.query(
|
116 |
+
# query_texts=["test"],
|
117 |
+
# n_results=1
|
118 |
+
# )
|
119 |
+
# print("Verified: Collection is queryable")
|
120 |
+
|
121 |
+
# return collection
|
122 |
+
|
123 |
+
# except Exception as e:
|
124 |
+
# print(f"Error creating ChromaDB: {str(e)}")
|
125 |
+
# return None
|
126 |
+
|
127 |
+
# def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
128 |
+
# """
|
129 |
+
# Retrieves relevant passages using ChromaDB's similarity search.
|
130 |
+
# """
|
131 |
+
# try:
|
132 |
+
# if db is None:
|
133 |
+
# print("Database not initialized")
|
134 |
+
# return []
|
135 |
+
|
136 |
+
# # Verify collection has documents
|
137 |
+
# count = db.count()
|
138 |
+
# if count == 0:
|
139 |
+
# print("Collection is empty")
|
140 |
+
# return []
|
141 |
+
|
142 |
+
# # Query the database
|
143 |
+
# results = db.query(
|
144 |
+
# query_texts=[query],
|
145 |
+
# n_results=min(n_results, count) # Ensure we don't request more than we have
|
146 |
+
# )
|
147 |
+
|
148 |
+
# # Ensure results exist
|
149 |
+
# if not results["documents"]:
|
150 |
+
# print("No relevant passages found.")
|
151 |
+
# return []
|
152 |
+
|
153 |
+
# documents = results["documents"][0] # First result batch
|
154 |
+
# distances = results["distances"][0] # Corresponding distances
|
155 |
+
|
156 |
+
# # Debug output
|
157 |
+
# print(f"Number of relevant passages retrieved: {len(documents)}")
|
158 |
+
# for i, (doc, distance) in enumerate(zip(documents, distances)):
|
159 |
+
# similarity = 1 - distance # Convert distance to similarity
|
160 |
+
# print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")
|
161 |
+
|
162 |
+
# return documents
|
163 |
+
# except Exception as e:
|
164 |
+
# print(f"Error in get_relevant_passage: {str(e)}")
|
165 |
+
# return []
|
166 |
+
|
167 |
+
|
168 |
+
import chromadb
|
169 |
+
from chromadb.config import Settings
|
170 |
+
from typing import List
|
171 |
+
import os
|
172 |
+
from gemini_embedding import GeminiEmbeddingFunction
|
173 |
+
import datetime
|
174 |
+
embedding_function = GeminiEmbeddingFunction()
|
175 |
+
|
176 |
+
def create_chroma_db(documents: List[str]):
|
177 |
+
"""
|
178 |
+
Creates a persistent Chroma database using the provided documents.
|
179 |
+
"""
|
180 |
+
# Create a persistent directory for ChromaDB
|
181 |
+
persist_directory = "chroma_db"
|
182 |
+
os.makedirs(persist_directory, exist_ok=True)
|
183 |
+
|
184 |
+
# Initialize the client with persistence
|
185 |
+
chroma_client = chromadb.PersistentClient(
|
186 |
+
path=persist_directory,
|
187 |
+
)
|
188 |
+
|
189 |
+
# Get or create collection
|
190 |
+
try:
|
191 |
+
# Try to get existing collection
|
192 |
+
db = chroma_client.get_collection(
|
193 |
+
name="document_collection",
|
194 |
+
embedding_function=embedding_function
|
195 |
+
)
|
196 |
+
# Clear existing documents
|
197 |
+
db.delete(db.get()["ids"])
|
198 |
+
except:
|
199 |
+
# Create new collection if it doesn't exist
|
200 |
+
db = chroma_client.create_collection(
|
201 |
+
name="document_collection",
|
202 |
+
embedding_function=embedding_function
|
203 |
+
)
|
204 |
+
|
205 |
+
# Add documents in batches to avoid memory issues
|
206 |
+
batch_size = 20
|
207 |
+
for i in range(0, len(documents), batch_size):
|
208 |
+
batch = documents[i:i + batch_size]
|
209 |
+
db.add(
|
210 |
+
documents=batch,
|
211 |
+
ids=[f"doc_{j}" for j in range(i, i + len(batch))]
|
212 |
+
)
|
213 |
+
|
214 |
+
return db
|
215 |
+
|
216 |
+
def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
217 |
+
start_time = datetime.datetime.now()
|
218 |
+
print(f"{start_time}: Starting ChromaDB query for question: {query[:50]}...") # Log query start
|
219 |
+
|
220 |
+
try:
|
221 |
+
results = db.query(
|
222 |
+
query_texts=[query],
|
223 |
+
n_results=min(n_results, db.count()),
|
224 |
+
include=['documents', 'distances']
|
225 |
+
)
|
226 |
+
end_time = datetime.datetime.now()
|
227 |
+
print(f"{end_time}: ChromaDB query finished. Time taken: {end_time - start_time}") # Log the time taken
|
228 |
+
# ... (rest of your get_relevant_passage function remains the same)
|
229 |
+
|
230 |
+
# Ensure results exist and contain at least one document
|
231 |
+
if not results or 'documents' not in results or not results['documents'] or not results['documents'][0]:
|
232 |
+
print("No relevant passages found.")
|
233 |
+
return []
|
234 |
+
|
235 |
+
# Extract valid results
|
236 |
+
documents = results['documents'][0] # List of retrieved documents
|
237 |
+
distances = results['distances'][0] # Corresponding similarity scores
|
238 |
+
|
239 |
+
# Debugging output
|
240 |
+
print(f"Number of relevant passages retrieved: {len(documents)}")
|
241 |
+
for i, (doc, distance) in enumerate(zip(documents, distances)):
|
242 |
+
similarity = 1 - distance # Convert distance to similarity score
|
243 |
+
print(f"Passage {i+1} (Similarity: {similarity:.4f}): {doc[:100]}...")
|
244 |
+
|
245 |
+
return documents # Return only valid results
|
246 |
+
except Exception as e:
|
247 |
+
print(f"Error in get_relevant_passage: {str(e)}")
|
248 |
+
return []
|
249 |
+
|
gemini_embedding.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import google.generativeai as genai
|
3 |
+
from chromadb.api.types import Documents, Embeddings
|
4 |
+
from chromadb import EmbeddingFunction
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
|
7 |
+
load_dotenv()
|
8 |
+
gemini_api_key = os.environ["GEMINI_API_KEY"]
|
9 |
+
|
10 |
+
class GeminiEmbeddingFunction(EmbeddingFunction):
|
11 |
+
"""
|
12 |
+
Custom embedding function using Gemini AI API.
|
13 |
+
"""
|
14 |
+
def __call__(self, input: Documents) -> Embeddings:
|
15 |
+
if not gemini_api_key:
|
16 |
+
raise ValueError("Gemini API Key not provided. Please set GEMINI_API_KEY as an environment variable.")
|
17 |
+
genai.configure(api_key=gemini_api_key)
|
18 |
+
model = "models/text-embedding-004"
|
19 |
+
return genai.embed_content(model=model, content=input, task_type="retrieval_document")["embedding"]
|
pdf_utils.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import pdfplumber
|
3 |
+
from typing import List, Optional
|
4 |
+
import textract
|
5 |
+
from docx import Document
|
6 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
# Configure logging
|
12 |
+
logging.basicConfig(level=logging.INFO)
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
def clean_text(text: str) -> str:
|
16 |
+
"""Clean extracted text by removing extra whitespace and invalid characters."""
|
17 |
+
text = re.sub(r'\s+', ' ', text) # Remove multiple spaces
|
18 |
+
text = ''.join(char for char in text if char.isprintable() or char == '\n') # Remove non-printable characters
|
19 |
+
text = re.sub(r'\n\s*\n', '\n\n', text) # Remove multiple newlines
|
20 |
+
return text.strip()
|
21 |
+
|
22 |
+
def extract_text_from_pdf(pdf_path: str) -> Optional[str]:
|
23 |
+
"""
|
24 |
+
Extract text from PDF using pdfplumber.
|
25 |
+
"""
|
26 |
+
extracted_text = []
|
27 |
+
try:
|
28 |
+
with pdfplumber.open(pdf_path) as pdf:
|
29 |
+
for page_num, page in enumerate(pdf.pages, 1):
|
30 |
+
try:
|
31 |
+
page_text = page.extract_text()
|
32 |
+
if page_text:
|
33 |
+
extracted_text.append(page_text)
|
34 |
+
else:
|
35 |
+
logger.warning(f"No text extracted from page {page_num}")
|
36 |
+
except Exception as e:
|
37 |
+
logger.error(f"Error extracting text from page {page_num}: {e}")
|
38 |
+
continue
|
39 |
+
|
40 |
+
if not extracted_text:
|
41 |
+
logger.warning("No text was extracted from any page of the PDF")
|
42 |
+
return None
|
43 |
+
|
44 |
+
return clean_text('\n'.join(extracted_text))
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"Failed to process PDF {pdf_path}: {e}")
|
47 |
+
return None
|
48 |
+
|
49 |
+
def extract_text_from_docx(docx_path: str) -> Optional[str]:
|
50 |
+
"""
|
51 |
+
Extract text from DOCX with enhanced error handling.
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
doc = Document(docx_path)
|
55 |
+
text = '\n'.join(para.text for para in doc.paragraphs if para.text.strip())
|
56 |
+
return clean_text(text) if text else None
|
57 |
+
except Exception as e:
|
58 |
+
logger.error(f"Failed to process DOCX {docx_path}: {e}")
|
59 |
+
return None
|
60 |
+
|
61 |
+
import tempfile
|
62 |
+
|
63 |
+
def extract_text_from_file(uploaded_file) -> Optional[str]:
|
64 |
+
"""
|
65 |
+
Extract text from various file types with enhanced error handling and logging.
|
66 |
+
If file is uploaded as file-like object, save it temporarily.
|
67 |
+
"""
|
68 |
+
if isinstance(uploaded_file, str): # Assuming file_path is a string for direct file handling
|
69 |
+
file_path = uploaded_file
|
70 |
+
else: # Handle file-like objects (e.g., uploaded files)
|
71 |
+
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
72 |
+
temp_file.write(uploaded_file.read()) # Save file contents temporarily
|
73 |
+
file_path = temp_file.name # Temporary file path
|
74 |
+
|
75 |
+
if not os.path.exists(file_path):
|
76 |
+
logger.error(f"File not found: {file_path}")
|
77 |
+
return None
|
78 |
+
|
79 |
+
_, file_extension = os.path.splitext(file_path)
|
80 |
+
file_extension = file_extension.lower()
|
81 |
+
|
82 |
+
try:
|
83 |
+
if file_extension == ".pdf":
|
84 |
+
text = extract_text_from_pdf(file_path)
|
85 |
+
elif file_extension == ".docx":
|
86 |
+
text = extract_text_from_docx(file_path)
|
87 |
+
elif file_extension == ".txt":
|
88 |
+
try:
|
89 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
90 |
+
text = clean_text(file.read())
|
91 |
+
except UnicodeDecodeError:
|
92 |
+
with open(file_path, "r", encoding="latin-1") as file:
|
93 |
+
text = clean_text(file.read())
|
94 |
+
else:
|
95 |
+
text = clean_text(textract.process(file_path).decode("utf-8"))
|
96 |
+
|
97 |
+
if not text:
|
98 |
+
logger.warning(f"No text content extracted from {file_path}")
|
99 |
+
return None
|
100 |
+
|
101 |
+
return text
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
logger.error(f"Error extracting text from {file_path}: {e}")
|
105 |
+
return None
|
106 |
+
|
107 |
+
|
108 |
+
def split_text(text: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> List[str]:
|
109 |
+
"""
|
110 |
+
Split text into chunks with improved handling and validation.
|
111 |
+
"""
|
112 |
+
if not text:
|
113 |
+
logger.warning("Empty text provided for splitting")
|
114 |
+
return []
|
115 |
+
|
116 |
+
try:
|
117 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
118 |
+
chunk_size=chunk_size,
|
119 |
+
chunk_overlap=chunk_overlap,
|
120 |
+
length_function=len,
|
121 |
+
is_separator_regex=False
|
122 |
+
)
|
123 |
+
|
124 |
+
splits = text_splitter.split_text(text)
|
125 |
+
|
126 |
+
logger.info(f"Split text into {len(splits)} chunks")
|
127 |
+
|
128 |
+
return splits
|
129 |
+
|
130 |
+
except Exception as e:
|
131 |
+
logger.error(f"Error splitting text: {e}")
|
132 |
+
return []
|
133 |
+
|
134 |
+
# Example usage
|
135 |
+
if __name__ == "__main__":
|
136 |
+
sample_file = "/Users/jessicawin/Downloads/github-recovery-codes.txt"
|
137 |
+
if os.path.exists(sample_file):
|
138 |
+
file_text = extract_text_from_file(sample_file)
|
139 |
+
if file_text:
|
140 |
+
chunks = split_text(file_text)
|
141 |
+
print(f"Successfully processed file into {len(chunks)} chunks")
|
query_handler.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import google.generativeai as genai
|
3 |
+
from chroma_db_utils import get_relevant_passage
|
4 |
+
import time
|
5 |
+
import datetime
|
6 |
+
import google.api_core.exceptions
|
7 |
+
|
8 |
+
# Constants
|
9 |
+
MAX_RETRIES = 3
|
10 |
+
RETRY_DELAY = 1 # Initial delay in seconds
|
11 |
+
MODEL_NAME = "gemini-1.5-flash"
|
12 |
+
REQUESTS_PER_MINUTE = 3 # Free tier limit
|
13 |
+
REQUEST_INTERVAL = 60 / REQUESTS_PER_MINUTE # Ensures we stay within the rate limit
|
14 |
+
|
15 |
+
def make_rag_prompt(query: str, relevant_passage: str) -> str:
|
16 |
+
"""
|
17 |
+
Creates a prompt for the RAG model.
|
18 |
+
"""
|
19 |
+
escaped = relevant_passage.replace("'", "").replace('"', "").replace("\n", " ")
|
20 |
+
prompt = f'''
|
21 |
+
You are a helpful and informative bot that answers questions using the REFERENCE TEXT below.
|
22 |
+
If the REFERENCE TEXT is irrelevant to the question, say "I cannot answer this question based on the provided information."
|
23 |
+
|
24 |
+
QUESTION: {query}
|
25 |
+
|
26 |
+
REFERENCE TEXT:
|
27 |
+
{escaped}
|
28 |
+
|
29 |
+
ANSWER:
|
30 |
+
'''
|
31 |
+
return prompt
|
32 |
+
|
33 |
+
def generate_answer(prompt: str) -> str:
|
34 |
+
"""
|
35 |
+
Calls the Gemini API with retries and rate limiting.
|
36 |
+
"""
|
37 |
+
gemini_api_key = os.getenv("GEMINI_API_KEY")
|
38 |
+
if not gemini_api_key:
|
39 |
+
raise ValueError("Gemini API Key not provided.")
|
40 |
+
|
41 |
+
genai.configure(api_key=gemini_api_key)
|
42 |
+
model = genai.GenerativeModel(MODEL_NAME)
|
43 |
+
|
44 |
+
for attempt in range(MAX_RETRIES):
|
45 |
+
start_time = datetime.datetime.now()
|
46 |
+
print(f"{start_time}: Making Gemini API request (attempt {attempt + 1}/{MAX_RETRIES})...")
|
47 |
+
try:
|
48 |
+
response = model.generate_content(prompt)
|
49 |
+
end_time = datetime.datetime.now()
|
50 |
+
print(f"{end_time}: Gemini API request successful. Time taken: {end_time - start_time}")
|
51 |
+
|
52 |
+
# Enforce rate limiting
|
53 |
+
time.sleep(REQUEST_INTERVAL)
|
54 |
+
return response.text
|
55 |
+
except google.api_core.exceptions.ResourceExhausted as e:
|
56 |
+
if e.code == 429: # Too Many Requests
|
57 |
+
delay = RETRY_DELAY * (2 ** attempt) # Exponential backoff
|
58 |
+
print(f"Rate limit hit. Retrying in {delay} seconds (attempt {attempt + 1}/{MAX_RETRIES})...")
|
59 |
+
time.sleep(delay)
|
60 |
+
else:
|
61 |
+
raise # Re-raise other exceptions
|
62 |
+
|
63 |
+
raise Exception("Max retries exceeded for Gemini API request.")
|
64 |
+
|
65 |
+
def handle_query(query: str, db, n_results: int = 5) -> str:
|
66 |
+
"""
|
67 |
+
Handles a user query by retrieving relevant passages and generating an answer.
|
68 |
+
"""
|
69 |
+
relevant_passages = get_relevant_passage(query, db, n_results)
|
70 |
+
relevant_passage_str = " ".join(relevant_passages)
|
71 |
+
prompt = make_rag_prompt(query, relevant_passage=relevant_passage_str)
|
72 |
+
return generate_answer(prompt)
|
requirement.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
google-generativeai>=0.3.0
|
2 |
+
chromadb
|
3 |
+
pdfplumber
|
4 |
+
python-docx>=0.8.11
|
5 |
+
textract>=1.6.5
|
6 |
+
langchain>=0.1.0
|
7 |
+
chromadb>=0.4.0
|
8 |
+
numpy>=1.21.0
|
9 |
+
python-dotenv>=0.19.0
|
10 |
+
streamlit>=1.18.0
|
11 |
+
typing>=3.7.4
|
12 |
+
warnings>=0.1.0
|
13 |
+
logging>=0.5.0
|