Update app.py

app.py (CHANGED)
Old version of the changed sections (removed or replaced lines are prefixed with "-"; lines cut off by the diff viewer are left as extracted):

@@ -1,25 +1,26 @@

import streamlit as st
import faiss
import numpy as np
- from sentence_transformers import SentenceTransformer
import requests
import os
import torch
import pickle
- from tqdm import tqdm
- from googleapiclient.discovery import build
- from google_auth_oauthlib.flow import InstalledAppFlow
- from google.auth.transport.requests import Request
- from google.oauth2.credentials import Credentials
import base64
import re
from pyngrok import ngrok

# ===============================
# 1. Streamlit App Configuration
# ===============================
st.set_page_config(page_title="📥 Email Chat Application", layout="wide")
- st.title("

# ===============================
# 2. Gmail Authentication Configuration
@@ -42,6 +43,12 @@ if "embeddings" not in st.session_state:

if "vector_store" not in st.session_state:
    st.session_state.vector_store = None

def count_tokens(text):
    return len(text.split())
@@ -57,7 +64,9 @@ def reset_session_state():

    st.session_state.data_chunks = []
    st.session_state.embeddings = None
    st.session_state.vector_store = None
-
        if os.path.exists(filename):
            os.remove(filename)
@@ -65,6 +74,7 @@ def authenticate_gmail(credentials_file):

    creds = None
    if os.path.exists('token.json'):
        try:
            creds = Credentials.from_authorized_user_file('token.json', SCOPES)
            if creds and creds.valid:
                st.session_state.creds = creds
@@ -132,18 +142,19 @@ def extract_email_body(payload):

    return ""

def combine_email_text(email):
    parts = []
    if email.get("sender"):
-       parts.append(
    if email.get("to"):
-       parts.append(
    if email.get("date"):
-       parts.append(
    if email.get("subject"):
-       parts.append(
    if email.get("body"):
-       parts.append(
-   return "

def create_chunks_from_gmail(service, label):
    try:
@@ -152,8 +163,7 @@ def create_chunks_from_gmail(service, label):

        messages.extend(result.get('messages', []))
        while 'nextPageToken' in result:
            token = result["nextPageToken"]
-           result = service.users().messages().list(userId='me', labelIds=[label],
-                                                     maxResults=500, pageToken=token).execute()
            messages.extend(result.get('messages', []))

        data_chunks = []
@@ -175,22 +185,28 @@ def create_chunks_from_gmail(service, label):

            email_dict['date'] = header.get('value', '')
            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
            data_chunks.append(email_dict)
-           progress_bar.progress((idx + 1) / total)
-       st.session_state.data_chunks
-       st.success(f"✅ Data chunks created successfully from
-       # Save chunks locally for future use.
-       with open("data_chunks.pkl", "wb") as f:
-           pickle.dump(data_chunks, f)
    except Exception as e:
-       st.error(f"❌ Error creating chunks from Gmail: {e}")

def embed_emails(email_chunks):
    st.header("π Embedding Data and Creating Vector Store")
    with st.spinner('π Embedding data...'):
        try:
-           embed_model =
-           device = 'cuda' if torch.cuda.is_available() else 'cpu'
-           embed_model.to(device)
            combined_texts = [combine_email_text(email) for email in email_chunks]
            batch_size = 64
            embeddings = []
@@ -203,6 +219,8 @@ def embed_emails(email_chunks):

                    device=device
                )
                embeddings.append(batch_embeddings)
            embeddings = np.vstack(embeddings)
            faiss.normalize_L2(embeddings)
            st.session_state.embeddings = embeddings
@@ -211,218 +229,255 @@ def embed_emails(email_chunks):

            index.add(embeddings)
            st.session_state.vector_store = index
            st.success("✅ Data embedding and vector store created successfully!")
-           # Save embeddings and index to disk.
-           with open('embeddings.pkl', 'wb') as f:
-               pickle.dump(embeddings, f)
-           faiss.write_index(index, 'vector_store.index')
        except Exception as e:
            st.error(f"❌ Error during embedding: {e}")

-       with open('embeddings.pkl', 'wb') as f:
-           pickle.dump(st.session_state.embeddings, f)
-       faiss.write_index(st.session_state.vector_store, 'vector_store.index')
-       st.success("💾 Embeddings and vector store saved successfully!")
-   except Exception as e:
-       st.error(f"❌ Error saving embeddings and vector store: {e}")
-
-def load_embeddings_and_index():
    try:
-       st.session_state.
    except Exception as e:
-       st.error(f"❌ Error
-
-def load_chunks():
-   try:
-       with open("data_chunks.pkl", "rb") as f:
-           st.session_state.data_chunks = pickle.load(f)
-       st.success("π Email chunks loaded successfully!")
-   except Exception as e:
-       st.error(f"❌ Error loading email chunks: {e}")

# ===============================
- # 5. Handling User Queries
# ===============================
def preprocess_query(query):
    return query.lower().strip()

def handle_user_query():
-   st.header("💬 Let's
-   if
        return
            )
-   # Boost candidates if sender or "to" field contains query tokens (e.g., email addresses).
-   query_tokens = re.findall(r'\S+@\S+', user_query)
-   if query_tokens:
-       for i in range(len(candidates)):
-           candidate_email_str = (
-               (candidates[i][0].get("sender", "") + " " + candidates[i][0].get("to", "")).lower()
-           )
-           for token in query_tokens:
-               if token.lower() in candidate_email_str:
-                   candidates[i] = (candidates[i][0], max(candidates[i][1], 1.0))
-       filtered_candidates = []
-       for candidate, score in candidates:
-           candidate_text = combine_email_text(candidate).lower()
-           if any(token.lower() in candidate_text for token in query_tokens):
-               filtered_candidates.append((candidate, score))
-       if filtered_candidates:
-           candidates = filtered_candidates
-       else:
-           st.info("No candidate emails contain the query token(s) exactly. Proceeding with all candidates.")
-
-   candidates.sort(key=lambda x: x[1], reverse=True)
-   if not candidates:
-       st.subheader("π AI Response:")
-       st.write("⚠️ No documents found.")
-       return
-   if candidates[0][1] < SIMILARITY_THRESHOLD:
-       st.subheader("π AI Response:")
-       st.write("⚠️ No document strongly matches your query. Try refining your query.")
-       return
-
-   # Re-rank candidates using the cross-encoder.
-   cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
-   candidate_pairs = [(user_query, combine_email_text(candidate[0])) for candidate in candidates]
-   rerank_scores = cross_encoder.predict(candidate_pairs)
-   reranked_candidates = [(candidates[i][0], rerank_scores[i]) for i in range(len(candidates))]
-   reranked_candidates.sort(key=lambda x: x[1], reverse=True)
-   retrieved_emails = [email for email, score in reranked_candidates]
-   retrieved_scores = [score for email, score in reranked_candidates]
-   average_similarity = np.mean(retrieved_scores)
-
-   # Build the final context string.
-   context_str = "\n\n".join([combine_email_text(email) for email in retrieved_emails])
-   MAX_CONTEXT_TOKENS = 500
-   context_tokens = context_str.split()
-   if len(context_tokens) > MAX_CONTEXT_TOKENS:
-       context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
-
-   payload = {
-       "model": "llama3-8b-8192",  # Adjust as needed.
-       "messages": [
-           {"role": "system", "content": f"Use the following context:\n{context_str}"},
-           {"role": "user", "content": user_query}
-       ]
-   }
-   api_key = "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"  # Replace with your API key.
-   url = "https://api.groq.com/openai/v1/chat/completions"
-   headers = {
-       "Authorization": f"Bearer {api_key}",
-       "Content-Type": "application/json"
-   }
-   response = requests.post(url, headers=headers, json=payload)
-   if response.status_code == 200:
-       response_json = response.json()
-       generated_text = response_json["choices"][0]["message"]["content"]
-       st.subheader("π AI Response:")
-       st.write(generated_text)
-       st.write(f"Average Re-Ranked Score: {average_similarity:.4f}")
-   else:
-       st.error(f"❌ Error from LLM API: {response.status_code} - {response.text}")
-   except Exception as e:
-       st.error(f"❌ An error occurred during processing: {e}")

# ===============================
# 6. Main Application Logic
# ===============================
def main():
    st.sidebar.header("π Gmail Authentication")
-   credentials_file = st.sidebar.file_uploader("π Upload
-   if credentials_file and st.sidebar.button("π Authenticate"):
-       reset_session_state()
-       with open("credentials.json", "wb") as f:
-           f.write(credentials_file.getbuffer())
-       authenticate_gmail("credentials.json")
-
-   # Option to load previously saved email chunks.
-   chunks_file = st.sidebar.file_uploader("π Upload saved email chunks (data_chunks.pkl)", type=["pkl"])
-   if chunks_file:
-       try:
-           st.session_state.data_chunks = pickle.load(chunks_file)
-           st.success("π Email chunks loaded successfully from upload!")
-       except Exception as e:
-           st.error(f"❌ Error loading uploaded email chunks: {e}")

-   if st.session_state.authenticated:
        st.sidebar.success("✅ You are authenticated!")
-       st.
-       if st.
            service = build('gmail', 'v1', credentials=st.session_state.creds)
            if st.session_state.data_chunks:
                embed_emails(st.session_state.data_chunks)
-   if
-       with st.expander("💾
-           st.success("💾 Email chunks saved to disk!")
-       except Exception as e:
-           st.error(f"❌ Error saving email chunks: {e}")
-   if st.button("💾 Save Embeddings & Vector Store"):
-       save_embeddings_and_index()
-   if (st.session_state.vector_store is not None and
-       st.session_state.embeddings is not None and
-       st.session_state.data_chunks is not None):
-       handle_user_query()
-   else:
-       st.warning("⚠️ You are not authenticated yet. Please authenticate to access your Gmail data.")

if __name__ == "__main__":
    main()
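Two asides on code removed above. First, the body of load_embeddings_and_index is cut off in this extract; judging from the save path it pairs with (embeddings.pkl written with pickle and vector_store.index written with faiss.write_index), the removed loader plausibly looked something like the sketch below, with the exact messages and attribute names being guesses:

def load_embeddings_and_index():
    try:
        # Assumed counterpart of the removed save logic above.
        with open('embeddings.pkl', 'rb') as f:
            st.session_state.embeddings = pickle.load(f)
        st.session_state.vector_store = faiss.read_index('vector_store.index')
        st.success("Embeddings and vector store loaded successfully!")
    except Exception as e:
        st.error(f"Error loading embeddings and vector store: {e}")

Second, the removed handle_user_query re-ranked the FAISS candidates with a sentence-transformers cross-encoder before calling the LLM. A minimal standalone sketch of that re-ranking step, assuming sentence-transformers is installed and reusing the model name from the removed code:

from sentence_transformers import CrossEncoder

def rerank(query, emails):
    # Score (query, email_text) pairs; higher scores indicate a better match.
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
    pairs = [(query, combine_email_text(email)) for email in emails]
    scores = cross_encoder.predict(pairs)
    return sorted(zip(emails, scores), key=lambda x: x[1], reverse=True)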
New version of the changed sections (added or changed lines are prefixed with "+"; "..." marks unchanged code that is not shown):

import streamlit as st
import faiss
import numpy as np
+ from sentence_transformers import SentenceTransformer
import requests
import os
import torch
import pickle
import base64
import re
from pyngrok import ngrok
+ from googleapiclient.discovery import build
+ from google_auth_oauthlib.flow import InstalledAppFlow
+ from google.auth.transport.requests import Request
+ import subprocess
+ import time
+ import sys

# ===============================
# 1. Streamlit App Configuration
# ===============================
st.set_page_config(page_title="📥 Email Chat Application", layout="wide")
+ st.title("💬 Turn Emails into Conversations—Effortless Chat with Your Inbox! 📩")

# ===============================
# 2. Gmail Authentication Configuration

...

if "vector_store" not in st.session_state:
    st.session_state.vector_store = None

+ # For storing candidate context details.
+ if "candidate_context" not in st.session_state:
+     st.session_state.candidate_context = None
+ if "raw_candidates" not in st.session_state:
+     st.session_state.raw_candidates = None
+
def count_tokens(text):
    return len(text.split())

...

    st.session_state.data_chunks = []
    st.session_state.embeddings = None
    st.session_state.vector_store = None
+   st.session_state.candidate_context = None
+   st.session_state.raw_candidates = None
+   for filename in ["token.json", "data_chunks.pkl", "embeddings.pkl", "vector_store.index", "vector_database.pkl"]:
        if os.path.exists(filename):
            os.remove(filename)

...

    creds = None
    if os.path.exists('token.json'):
        try:
+           from google.oauth2.credentials import Credentials
            creds = Credentials.from_authorized_user_file('token.json', SCOPES)
            if creds and creds.valid:
                st.session_state.creds = creds

...

    return ""

def combine_email_text(email):
+   # Build the complete email text by joining parts with HTML line breaks.
    parts = []
    if email.get("sender"):
+       parts.append("From: " + email["sender"])
    if email.get("to"):
+       parts.append("To: " + email["to"])
    if email.get("date"):
+       parts.append("Date: " + email["date"])
    if email.get("subject"):
+       parts.append("Subject: " + email["subject"])
    if email.get("body"):
+       parts.append("Body: " + email["body"])
+   return "<br>".join(parts)

def create_chunks_from_gmail(service, label):
    try:

...

        messages.extend(result.get('messages', []))
        while 'nextPageToken' in result:
            token = result["nextPageToken"]
+           result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500, pageToken=token).execute()
            messages.extend(result.get('messages', []))

        data_chunks = []

...
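The paging loop above is the standard Gmail API pattern: messages.list returns at most one page of ids (capped at 500 here) plus a nextPageToken, which is passed back until it disappears. As an isolated sketch of the same pattern, assuming an already-authenticated service object:

def list_message_ids(service, label):
    # Collect message ids for one label, following nextPageToken until exhausted.
    result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
    ids = [m['id'] for m in result.get('messages', [])]
    while 'nextPageToken' in result:
        result = service.users().messages().list(
            userId='me', labelIds=[label], maxResults=500,
            pageToken=result['nextPageToken']
        ).execute()
        ids.extend(m['id'] for m in result.get('messages', []))
    return ids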
            email_dict['date'] = header.get('value', '')
            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
            data_chunks.append(email_dict)
+           progress_bar.progress(min((idx + 1) / total, 1.0))
+       st.session_state.data_chunks.extend(data_chunks)
+       st.success(f"✅ Data chunks created successfully from {label}! Total emails processed for this label: {len(data_chunks)}")
    except Exception as e:
+       st.error(f"❌ Error creating chunks from Gmail for label {label}: {e}")
+
+ # -------------------------------
+ # Cached model loaders for efficiency
+ # -------------------------------
+ @st.cache_resource
+ def get_embed_model():
+     model = SentenceTransformer("all-MiniLM-L6-v2")
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model.to(device)
+     return model, device

def embed_emails(email_chunks):
    st.header("π Embedding Data and Creating Vector Store")
+   progress_bar = st.progress(0)
    with st.spinner('π Embedding data...'):
        try:
+           embed_model, device = get_embed_model()
            combined_texts = [combine_email_text(email) for email in email_chunks]
            batch_size = 64
            embeddings = []

...

                    device=device
                )
                embeddings.append(batch_embeddings)
+               progress_value = min((i + batch_size) / len(combined_texts), 1.0)
+               progress_bar.progress(progress_value)
            embeddings = np.vstack(embeddings)
            faiss.normalize_L2(embeddings)
            st.session_state.embeddings = embeddings

...

            index.add(embeddings)
            st.session_state.vector_store = index
            st.success("✅ Data embedding and vector store created successfully!")
        except Exception as e:
            st.error(f"❌ Error during embedding: {e}")

+ # New function to save the entire vector database as a single pickle file.
+ def save_vector_database():
    try:
+       vector_db = {
+           "vector_store": st.session_state.vector_store,
+           "embeddings": st.session_state.embeddings,
+           "data_chunks": st.session_state.data_chunks
+       }
+       db_data = pickle.dumps(vector_db)
+       st.download_button(
+           label="Download Vector Database",
+           data=db_data,
+           file_name="vector_database.pkl",
+           mime="application/octet-stream"
+       )
    except Exception as e:
+       st.error(f"❌ Error saving vector database: {e}")
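One caveat on save_vector_database above: a FAISS index is a thin wrapper around a C++ object, and pickling it directly may fail or may not restore cleanly depending on the installed FAISS build. If that turns out to be a problem, a common workaround is to serialize the index into a byte buffer first; a sketch under the same session-state layout:

# Serialize the index into a plain numpy buffer that pickles safely.
index_bytes = faiss.serialize_index(st.session_state.vector_store)
vector_db = {
    "vector_store": index_bytes,
    "embeddings": st.session_state.embeddings,
    "data_chunks": st.session_state.data_chunks,
}
db_data = pickle.dumps(vector_db)

# On load, rebuild the index from the buffer.
restored = pickle.loads(db_data)
index = faiss.deserialize_index(restored["vector_store"])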
# ===============================
+ # 5. Handling User Queries (User-Controlled Threshold)
# ===============================
def preprocess_query(query):
    return query.lower().strip()

+ def process_candidate_emails(query, similarity_threshold):
+     """
+     Process the query by computing its embedding, searching the vector store,
+     filtering candidates based on a similarity threshold, and building a context string.
+     """
+     TOP_K = 20  # Increased to allow for threshold filtering
+
+     # Reset candidate context for each query
+     st.session_state.candidate_context = None
+     st.session_state.raw_candidates = None
+
+     if st.session_state.vector_store is None:
+         st.error("❌ Please process your email data or load a saved vector database first.")
+         return
+
+     with st.spinner('π Processing your query...'):
+         try:
+             embed_model, device = get_embed_model()
+             processed_query = preprocess_query(query)
+             query_embedding = embed_model.encode(
+                 [processed_query],
+                 convert_to_numpy=True,
+                 show_progress_bar=False,
+                 device=device
+             )
+             faiss.normalize_L2(query_embedding)
+
+             # Debug: Verify the type of vector_store
+             st.write(f"Vector Store Type: {type(st.session_state.vector_store)}")
+
+             # Perform search
+             distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
+             candidates = []
+             for idx, sim in zip(indices[0], distances[0]):
+                 # Include candidate only if similarity meets the threshold
+                 if sim >= similarity_threshold:
+                     candidates.append((st.session_state.data_chunks[idx], sim))
+             if not candidates:
+                 st.write("⚠️ No matching embeddings found for your query with the selected threshold.")
+                 return
+
+             # Build the context string by concatenating all matching email texts using HTML breaks.
+             context_str = ""
+             for candidate, sim in candidates:
+                 context_str += combine_email_text(candidate) + "<br><br>"
+
+             # Optionally limit context size.
+             MAX_CONTEXT_TOKENS = 500
+             context_tokens = context_str.split()
+             if len(context_tokens) > MAX_CONTEXT_TOKENS:
+                 context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
+
+             st.session_state.candidate_context = context_str
+             st.session_state.raw_candidates = candidates
+             st.success("✅ Candidates retrieved and context built!")
+         except Exception as e:
+             st.error(f"❌ An error occurred during processing: {e}")
+
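A note on the scores used above: embed_emails L2-normalizes the embeddings and stores them in an IndexFlatIP, so the distances returned by vector_store.search are inner products of unit vectors, i.e. cosine similarities roughly in [-1, 1]. That is what makes a fixed cutoff such as the default 0.3 threshold meaningful. A small self-contained illustration of the same scoring (names and sizes here are only for demonstration):

import numpy as np
import faiss

vecs = np.random.rand(10, 384).astype("float32")   # 384 = all-MiniLM-L6-v2 output size
faiss.normalize_L2(vecs)
index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs)

query = vecs[:1].copy()
scores, ids = index.search(query, 5)
# scores[0][0] is ~1.0 because the query matches itself; hits scoring
# below the chosen threshold (0.3 in the app) would be filtered out.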
|
+ def call_llm_api(query):
+     """
+     Send the user's query along with the concatenated matching email texts (context)
+     to the LLM API and display the AI response.
+     """
+     if not st.session_state.candidate_context:
+         st.error("❌ No candidate context available. Please try again.")
+         return
+
+     # Retrieve the API key from the environment variable 'GroqAPI'
+     api_key = os.getenv("GroqAPI")
+     if not api_key:
+         st.error("❌ API key not found. Please ensure 'GroqAPI' is set in Hugging Face Secrets.")
+         return
+
+     payload = {
+         "model": "llama-3.3-70b-versatile",  # Adjust model as needed.
+         "messages": [
+             {"role": "system", "content": f"Use the following context:<br>{st.session_state.candidate_context}"},
+             {"role": "user", "content": query}
+         ]
+     }
+     url = "https://api.groq.com/openai/v1/chat/completions"  # Verify this endpoint
+
+     headers = {
+         "Authorization": f"Bearer {api_key}",
+         "Content-Type": "application/json"
+     }
+
+     with st.spinner("π Fetching AI response..."):
+         try:
+             response = requests.post(url, headers=headers, json=payload)
+             response.raise_for_status()  # Raises stored HTTPError, if one occurred.
+             response_json = response.json()
+             generated_text = response_json["choices"][0]["message"]["content"]
+             st.subheader("π AI Response:")
+             st.write(generated_text)
+         except requests.exceptions.HTTPError as http_err:
+             try:
+                 error_info = response.json().get("error", {})
+                 error_message = error_info.get("message", "An unknown error occurred.")
+                 st.error(f"❌ HTTP error occurred: {error_message}")
+             except ValueError:
+                 st.error(f"❌ HTTP error occurred: {response.status_code} - {response.text}")
+         except Exception as err:
+             st.error(f"❌ An unexpected error occurred: {err}")
+
def handle_user_query():
+   st.header("💬 Let's Chat with Your Emails")
+
+   # Checkbox to show/hide the threshold slider
+   show_threshold = st.checkbox("Adjust Similarity Threshold")
+
+   # Slider, shown only if 'show_threshold' is True
+   if show_threshold:
+       similarity_threshold = st.slider(
+           "Select Similarity Threshold",
+           min_value=0.0,
+           max_value=1.0,
+           value=0.3,
+           step=0.05,
+           help="Adjust the similarity threshold to control the relevance of retrieved emails. Higher values yield more relevant results.",
+           key='similarity_threshold'
+       )
+   else:
+       # Set a default threshold if the slider is not shown
+       if 'similarity_threshold' not in st.session_state:
+           st.session_state.similarity_threshold = 0.3
+       similarity_threshold = st.session_state.similarity_threshold
+
+   # Callback function to process the query
+   def query_callback():
+       query = st.session_state.query_input
+       if not query.strip():
            return
+       process_candidate_emails(query, similarity_threshold)
+       if st.session_state.raw_candidates:
+           st.subheader("π Matching Email Chunks:")
+           for candidate, sim in st.session_state.raw_candidates:
+               # Get a snippet (first 150 characters) of the body instead of full body content.
+               body = candidate.get('body', 'No Content')
+               snippet = (body[:150] + "...") if len(body) > 150 else body
+               st.markdown(
+                   f"**From:** {candidate.get('sender','Unknown')} <br>"
+                   f"**To:** {candidate.get('to','Unknown')} <br>"
+                   f"**Date:** {candidate.get('date','Unknown')} <br>"
+                   f"**Subject:** {candidate.get('subject','No Subject')} <br>"
+                   f"**Body Snippet:** {snippet} <br>"
+                   f"**Similarity:** {sim:.4f}",
+                   unsafe_allow_html=True
                )
+       # Then send the query along with the context to the LLM API.
+       call_llm_api(query)
+
+   # Text input with callback on change (when Enter is pressed)
+   st.text_input("Enter your query:", key="query_input", on_change=query_callback)
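handle_user_query above leans on Streamlit's callback pattern: the text input writes its value into st.session_state under the "query_input" key, and on_change fires query_callback when the user presses Enter, so the query is processed before the script reruns. A stripped-down version of the same pattern, with a placeholder in place of the real processing:

import streamlit as st

def on_query_change():
    # Inside the callback, read the widget value back out of session state.
    query = st.session_state.query_input
    if query.strip():
        st.session_state.last_query = query  # stand-in for real processing

st.text_input("Enter your query:", key="query_input", on_change=on_query_change)
st.write(st.session_state.get("last_query", ""))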
# ===============================
# 6. Main Application Logic
# ===============================
def main():
    st.sidebar.header("π Gmail Authentication")
+   credentials_file = st.sidebar.file_uploader("π Upload credentials.json", type=["json"])

+   data_management_option = st.sidebar.selectbox(
+       "Choose an option",
+       ["Upload Pre-existing Data", "Authenticate and Create New Data"],
+       index=1  # Default to "Authenticate and Create New Data"
+   )
+
+   if data_management_option == "Upload Pre-existing Data":
+       uploaded_db = st.sidebar.file_uploader("π Upload vector database (vector_database.pkl)", type=["pkl"])
+       if uploaded_db:
+           # Check file size; if larger than 200MB, show a warning and then continue.
+           file_size_mb = uploaded_db.size / (1024 * 1024)
+           if file_size_mb > 200:
+               st.warning("The uploaded file is larger than 200MB. It may take longer to load, but processing will continue.")
+           try:
+               vector_db = pickle.load(uploaded_db)
+               st.session_state.vector_store = vector_db.get("vector_store")
+               st.session_state.embeddings = vector_db.get("embeddings")
+               st.session_state.data_chunks = vector_db.get("data_chunks")
+               st.success("π Vector database loaded successfully from upload!")
+           except Exception as e:
+               st.error(f"❌ Error loading vector database: {e}")
+   elif data_management_option == "Authenticate and Create New Data":
+       if credentials_file and st.sidebar.button("π Authenticate"):
+           reset_session_state()
+           with open("credentials.json", "wb") as f:
+               f.write(credentials_file.getbuffer())
+           authenticate_gmail("credentials.json")

+       if st.session_state.auth_url:
+           st.sidebar.markdown("### π **Authorization URL:**")
+           st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
+           st.sidebar.text_input("π Enter the authorization code:", key="auth_code")
+           if st.sidebar.button("✅ Submit Authentication Code"):
+               submit_auth_code()

+   if data_management_option == "Authenticate and Create New Data" and st.session_state.authenticated:
        st.sidebar.success("✅ You are authenticated!")
+       st.header("π Data Management")
+       # Multi-select widget for folders (labels)
+       folders = st.multiselect("Select Labels (Folders) to Process Emails From:",
+                                ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"], default=["INBOX"])
+       if st.button("📥 Create Chunks and Embed Data"):
            service = build('gmail', 'v1', credentials=st.session_state.creds)
+           all_chunks = []
+           # Process each selected folder
+           for folder in folders:
+               # Clear temporary data_chunks so that each folder's data is separate
+               st.session_state.data_chunks = []
+               create_chunks_from_gmail(service, folder)
+               if st.session_state.data_chunks:
+                   all_chunks.extend(st.session_state.data_chunks)
+           st.session_state.data_chunks = all_chunks
            if st.session_state.data_chunks:
                embed_emails(st.session_state.data_chunks)
+       if st.session_state.vector_store is not None:
+           with st.expander("💾 Download Data", expanded=True):
+               save_vector_database()
+
+   if st.session_state.vector_store is not None:
+       handle_user_query()

if __name__ == "__main__":
    main()
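For completeness, the vector_database.pkl produced by the download button is exactly what the "Upload Pre-existing Data" path expects back. A small script along these lines could inspect such a file outside the app, assuming the pickle really contains the dict written by save_vector_database and that the FAISS index survived pickling:

import pickle

with open("vector_database.pkl", "rb") as f:
    vector_db = pickle.load(f)

index = vector_db.get("vector_store")      # FAISS index, if it unpickled cleanly
embeddings = vector_db.get("embeddings")   # numpy array of normalized embeddings
chunks = vector_db.get("data_chunks")      # list of email dicts

print(type(index), getattr(embeddings, "shape", None), len(chunks or []))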