Spaces:
Running
Running
File size: 22,551 Bytes
6170ca9 dfa3a2c 08f287a dfa3a2c 6170ca9 08f287a 6170ca9 dfa3a2c 08f287a 6170ca9 872e099 73f075b 144b99c 73f075b dfa3a2c 73f075b 08f287a 71416c7 872e099 dfa3a2c 3df78c0 dfa3a2c 08f287a 71416c7 08f287a dfa3a2c 3df78c0 dfa3a2c 73f075b 6170ca9 08f287a 73f075b 6c44683 160f1c7 6170ca9 dfa3a2c 73f075b dfa3a2c 73f075b 6c44683 dfa3a2c 2cf78f9 73f075b dfa3a2c 6a1fae7 144b99c 6c44683 144b99c 6c44683 144b99c 6c44683 3df78c0 144b99c 6c44683 dfa3a2c 872e099 dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a 872e099 e4652f2 71416c7 dfa3a2c e4652f2 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c e4652f2 dfa3a2c 08f287a dfa3a2c 08f287a 81ba234 08f287a dfa3a2c 08f287a dfa3a2c 872e099 dfa3a2c 08f287a e4652f2 08f287a 872e099 08f287a e4652f2 08f287a e4652f2 08f287a dfa3a2c 08f287a 81ba234 08f287a 71416c7 e4652f2 71416c7 e4652f2 71416c7 8d8049a 71416c7 e4652f2 8d8049a dfa3a2c 872e099 dfa3a2c 872e099 dfa3a2c 08f287a dfa3a2c 08f287a 81ba234 08f287a 872e099 71416c7 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a dfa3a2c 08f287a 81ba234 08f287a dfa3a2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 |
import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import requests
import os
import torch
import pickle
import base64
from googleapiclient.discovery import build
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request
# ===============================
# 1. Streamlit App Configuration
# ===============================
st.set_page_config(page_title="π₯ Email Chat Application", layout="wide")
st.title("π¬ Turn Emails into ConversationsβEffortless Chat with Your Inbox! π©")
# ===============================
# 2. Initialize Session State Variables
# ===============================
# Seed every session-state key with its default so the rest of the app can
# assume the keys exist. The dict literal is rebuilt on every script rerun,
# so the mutable defaults ([] below) are fresh objects per new session.
_SESSION_DEFAULTS = {
    "authenticated": False,             # True once Gmail OAuth has completed
    "creds": None,                      # OAuth credentials object
    "auth_url": None,                   # consent URL shown to the user
    "auth_code": "",                    # authorization code pasted by the user
    "flow": None,                       # InstalledAppFlow instance
    "data_chunks": [],                  # list of extracted email dicts
    "embeddings": None,                 # embedding matrix for the emails
    "vector_store": None,               # FAISS index over the embeddings
    "candidate_context": None,          # concatenated context for the LLM
    "raw_candidates": None,             # (email_dict, similarity) pairs
    "messages": [],                     # chat history
    # One-shot flags so success banners are shown only once.
    "candidates_message_shown": False,
    "vector_db_message_shown": False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
def count_tokens(text):
    """Approximate a token count as the number of whitespace-separated words."""
    words = text.split()
    return len(words)
# ===============================
# 3. Gmail Authentication Functions (Updated)
# ===============================
def reset_session_state():
    """Reset every session-state key to its default and remove any
    credential/embedding files cached on disk by a previous run."""
    defaults = {
        "authenticated": False,
        "creds": None,
        "auth_url": None,
        "auth_code": "",
        "flow": None,
        "data_chunks": [],
        "embeddings": None,
        "vector_store": None,
        "candidate_context": None,
        "raw_candidates": None,
        "messages": [],
        "candidates_message_shown": False,
        "vector_db_message_shown": False,
    }
    for key, value in defaults.items():
        st.session_state[key] = value
    # Drop on-disk artifacts so the next authentication starts clean.
    stale_files = ("token.json", "data_chunks.pkl", "embeddings.pkl",
                   "vector_store.index", "vector_database.pkl")
    for path in stale_files:
        if os.path.exists(path):
            os.remove(path)
def authenticate_gmail(credentials_file):
    """Authenticate against the Gmail API via OAuth2.

    Tries, in order: a cached token.json, refreshing an expired token, and
    finally launching a new OAuth consent flow (the consent URL is stored in
    session state for the user to visit).

    Args:
        credentials_file: Path to the client-secrets JSON file from
            Google Cloud Console.

    Returns:
        Valid credentials on success, or None while the consent flow is
        still pending user action.
    """
    SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
    creds = None
    # 1) Reuse a previously stored token when it is still valid.
    if os.path.exists('token.json'):
        try:
            from google.oauth2.credentials import Credentials
            creds = Credentials.from_authorized_user_file('token.json', SCOPES)
            if creds and creds.valid:
                st.session_state.creds = creds
                st.session_state.authenticated = True
                if not st.session_state.candidates_message_shown:
                    # BUG FIX: this message was split across two source lines,
                    # producing an unterminated string literal.
                    st.success("✅ Authentication successful!")
                    st.session_state.candidates_message_shown = True
                return creds
        except Exception as e:
            st.error(f"β Invalid token.json file: {e}")
            os.remove('token.json')
    if not creds or not creds.valid:
        # 2) Refresh an expired token when a refresh token is available.
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
            st.session_state.creds = creds
            st.session_state.authenticated = True
            if not st.session_state.candidates_message_shown:
                st.success("✅ Authentication successful!")
                st.session_state.candidates_message_shown = True
            # Persist the refreshed token for subsequent runs.
            with open('token.json', 'w') as token_file:
                token_file.write(creds.to_json())
            return creds
        else:
            # 3) Start a fresh consent flow; the user must visit the URL and
            # paste the authorization code back into the sidebar.
            if not st.session_state.flow:
                st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES)
                st.session_state.flow.redirect_uri = 'http://localhost'
                auth_url, _ = st.session_state.flow.authorization_url(prompt='consent')
                st.session_state.auth_url = auth_url
                st.info("π **Authorize the application by visiting the URL below:**")
                st.markdown(f"[Authorize]({st.session_state.auth_url})")
def submit_auth_code():
    """Exchange the pasted authorization code for credentials and persist them.

    On success: writes token.json and sets st.session_state.authenticated.
    On any failure: leaves the app unauthenticated and shows the error.
    """
    try:
        # Exchange the authorization code for an access/refresh token pair.
        st.session_state.flow.fetch_token(code=st.session_state.auth_code)
        st.session_state.creds = st.session_state.flow.credentials
        # Persist the credentials so later runs can skip the consent flow.
        with open('token.json', 'w') as token_file:
            token_json = st.session_state.creds.to_json()
            token_file.write(token_json)
        # Only mark the session authenticated after the write succeeded.
        st.session_state.authenticated = True
        # BUG FIX: this message was split across two source lines in the
        # original, producing an unterminated string literal.
        st.success("✅ Authentication successful!")
    except Exception as e:
        # Any error (bad code, token write failure) leaves us unauthenticated.
        st.session_state.authenticated = False
        st.error(f"β Error during authentication: {e}")
# ===============================
# 4. Email Data Extraction, Embedding and Vector Store Functions
# ===============================
def extract_email_body(payload):
    """Pull a plain-text body out of a Gmail message payload.

    Preference order: the top-level body, then the first text/plain part,
    then the first part of any MIME type. Returns "" when nothing decodes.
    """
    def _decode(data, label):
        # URL-safe base64 decode; report failures but keep the app running.
        try:
            return base64.urlsafe_b64decode(data.encode('UTF-8')).decode('UTF-8')
        except Exception as e:
            st.error(f"Error decoding {label}: {e}")
            return None

    top_body = payload.get('body', {})
    if 'data' in top_body and top_body['data']:
        decoded = _decode(top_body['data'], "email body")
        return decoded if decoded is not None else ""
    parts = payload.get('parts')
    if parts:
        # Prefer an explicit text/plain part; skip ones that fail to decode.
        for part in parts:
            if part.get('mimeType') == 'text/plain' and 'data' in part.get('body', {}):
                decoded = _decode(part['body']['data'], "email part")
                if decoded is not None:
                    return decoded
        # Fall back to the first part regardless of MIME type.
        first_part = parts[0]
        if 'data' in first_part.get('body', {}):
            decoded = _decode(first_part['body']['data'], "fallback email part")
            return decoded if decoded is not None else ""
        return ""
    return ""
def combine_email_text(email):
    """Render an email dict as one HTML string, fields joined by <br>.

    Missing or empty fields are skipped; order is From, To, Date,
    Subject, Body.
    """
    field_order = (
        ("From: ", "sender"),
        ("To: ", "to"),
        ("Date: ", "date"),
        ("Subject: ", "subject"),
        ("Body: ", "body"),
    )
    segments = [label + email[key] for label, key in field_order if email.get(key)]
    return "<br>".join(segments)
def create_chunks_from_gmail(service, label):
    """Fetch every message under a Gmail label and append one dict per email
    to st.session_state.data_chunks.

    Each dict carries: id, body, and (when the header exists) sender, to,
    date, subject.

    Args:
        service: An authorized Gmail API service object.
        label: Gmail label ID to fetch (e.g. 'INBOX').
    """
    try:
        # Page through the complete message-id list (500 ids per page).
        messages = []
        result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
        messages.extend(result.get('messages', []))
        while 'nextPageToken' in result:
            token = result["nextPageToken"]
            result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500, pageToken=token).execute()
            messages.extend(result.get('messages', []))
        data_chunks = []
        progress_bar = st.progress(0)
        total = len(messages)
        # Map interesting header names (lowercased) to email_dict keys.
        header_keys = {'from': 'sender', 'subject': 'subject', 'to': 'to', 'date': 'date'}
        for idx, msg in enumerate(messages):
            msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute()
            headers = msg_data.get('payload', {}).get('headers', [])
            email_dict = {"id": msg['id']}
            for header in headers:
                key = header_keys.get(header.get('name', '').lower())
                if key:
                    email_dict[key] = header.get('value', '')
            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
            data_chunks.append(email_dict)
            progress_bar.progress(min((idx + 1) / total, 1.0))
        st.session_state.data_chunks.extend(data_chunks)
        if not st.session_state.vector_db_message_shown:
            # BUG FIX: the original success text was copy-pasted from the
            # vector-database-upload path and misdescribed this step.
            st.success(f"✅ Emails fetched successfully! Total emails processed for label '{label}': {len(data_chunks)}")
            st.session_state.vector_db_message_shown = True
    except Exception as e:
        st.error(f"β Error creating chunks from Gmail for label '{label}': {e}")
# -------------------------------
# Cached model loaders for efficiency
# -------------------------------
@st.cache_resource
def get_embed_model():
    """Load the sentence-embedding model once per process (Streamlit-cached).

    Returns:
        (model, device): the SentenceTransformer moved to CUDA when
        available, and the device string it lives on.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer("all-MiniLM-L6-v2")
    model.to(device)
    return model, device
def embed_emails(email_chunks):
    """Embed all emails and build a FAISS inner-product index over them.

    Stores the L2-normalized embedding matrix in st.session_state.embeddings
    and the index in st.session_state.vector_store (normalized inner product
    behaves as cosine similarity).

    Args:
        email_chunks: List of email dicts as produced by
            create_chunks_from_gmail.
    """
    st.header("π Embedding Data and Creating Vector Store")
    progress_bar = st.progress(0)
    with st.spinner('π Embedding data...'):
        try:
            # BUG FIX: guard the empty case — np.vstack([]) would raise.
            if not email_chunks:
                st.warning("⚠️ No emails to embed.")
                return
            embed_model, device = get_embed_model()
            combined_texts = [combine_email_text(email) for email in email_chunks]
            batch_size = 64
            embeddings = []
            # Encode in batches so the progress bar can advance.
            for i in range(0, len(combined_texts), batch_size):
                batch = combined_texts[i:i + batch_size]
                batch_embeddings = embed_model.encode(
                    batch,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device=device
                )
                embeddings.append(batch_embeddings)
                progress_value = min((i + batch_size) / len(combined_texts), 1.0)
                progress_bar.progress(progress_value)
            embeddings = np.vstack(embeddings)
            # Normalize so IndexFlatIP search returns cosine similarities.
            faiss.normalize_L2(embeddings)
            st.session_state.embeddings = embeddings
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatIP(dimension)
            index.add(embeddings)
            st.session_state.vector_store = index
            if not st.session_state.candidates_message_shown:
                # BUG FIX: message was split across two source lines in the
                # original, producing an unterminated string literal.
                st.success("✅ Data embedding and vector store created successfully!")
                st.session_state.candidates_message_shown = True
        except Exception as e:
            st.error(f"β Error during embedding: {e}")
# New function to save the entire vector database as a single pickle file.
def save_vector_database():
    """Offer the current index, embeddings and email chunks as a single
    pickle download via a Streamlit download button."""
    try:
        serialized = pickle.dumps({
            "vector_store": st.session_state.vector_store,
            "embeddings": st.session_state.embeddings,
            "data_chunks": st.session_state.data_chunks,
        })
        st.download_button(
            label="πΎ Download Vector Database",
            data=serialized,
            file_name="vector_database.pkl",
            mime="application/octet-stream"
        )
    except Exception as e:
        st.error(f"β Error saving vector database: {e}")
# ===============================
# 5. Handling User Queries (User-Controlled Threshold)
# ===============================
def preprocess_query(query):
    """Normalize a query: trim surrounding whitespace and lowercase it."""
    return query.strip().lower()
def process_candidate_emails(query, similarity_threshold):
    """
    Embed the query, search the vector store, keep candidates whose cosine
    similarity meets the threshold, and build the LLM context string.

    Results land in st.session_state.candidate_context (concatenated email
    texts, possibly truncated) and st.session_state.raw_candidates
    ((email_dict, similarity) pairs). When nothing matches, a warning is
    appended to the chat instead.
    """
    TOP_K = 20  # Retrieve extra hits so threshold filtering still leaves results.
    # Reset candidate context for each query.
    st.session_state.candidate_context = None
    st.session_state.raw_candidates = None
    if st.session_state.vector_store is None:
        st.error("β Please process your email data or load a saved vector database first.")
        return
    try:
        embed_model, device = get_embed_model()
        processed_query = preprocess_query(query)
        query_embedding = embed_model.encode(
            [processed_query],
            convert_to_numpy=True,
            show_progress_bar=False,
            device=device
        )
        # Normalize so inner-product scores are cosine similarities.
        faiss.normalize_L2(query_embedding)
        distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
        candidates = []
        for idx, sim in zip(indices[0], distances[0]):
            # BUG FIX: FAISS pads with -1 when the index holds fewer than
            # TOP_K vectors; the original indexed data_chunks[-1] and could
            # return a bogus candidate.
            if idx < 0:
                continue
            # Include candidate only if similarity meets the threshold.
            if sim >= similarity_threshold:
                candidates.append((st.session_state.data_chunks[idx], sim))
        if not candidates:
            # Surface the miss as an assistant chat message.
            st.session_state.messages.append({"role": "assistant", "content": "β οΈ No matching embeddings found for your query with the selected threshold."})
            return
        # Concatenate matching email texts (HTML breaks) into one context blob.
        context_str = ""
        for candidate, sim in candidates:
            context_str += combine_email_text(candidate) + "<br><br>"
        # Cap context size (counted in whitespace tokens, not model tokens).
        MAX_CONTEXT_TOKENS = 500
        context_tokens = context_str.split()
        if len(context_tokens) > MAX_CONTEXT_TOKENS:
            context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
        st.session_state.candidate_context = context_str
        st.session_state.raw_candidates = candidates
    except Exception as e:
        st.error(f"β An error occurred during processing: {e}")
def call_llm_api(query):
    """
    Post the user's query plus the retrieved email context to the Groq
    chat-completions API and append the model's reply to the chat history.
    Errors are appended to the chat as assistant messages rather than raised.
    """
    if not st.session_state.candidate_context:
        st.error("β No candidate context available. Please try again.")
        return
    # The key is expected in the 'GroqAPI' environment variable.
    api_key = os.getenv("GroqAPI")
    if not api_key:
        st.error("β API key not found. Please ensure 'GroqAPI' is set in Hugging Face Secrets.")
        return
    request_body = {
        "model": "llama-3.3-70b-versatile",  # Adjust model as needed.
        "messages": [
            {"role": "system", "content": f"Use the following context:\n{st.session_state.candidate_context}"},
            {"role": "user", "content": query},
        ],
    }
    endpoint = "https://api.groq.com/openai/v1/chat/completions"  # Verify this endpoint
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(endpoint, headers=request_headers, json=request_body)
        response.raise_for_status()  # Raises an HTTPError for 4xx/5xx statuses.
        reply = response.json()["choices"][0]["message"]["content"]
        st.session_state.messages.append({"role": "assistant", "content": reply})
    except requests.exceptions.HTTPError:
        # Prefer the API's own error message when the body is JSON.
        try:
            error_message = response.json().get("error", {}).get("message", "An unknown error occurred.")
            st.session_state.messages.append({"role": "assistant", "content": f"β HTTP error occurred: {error_message}"})
        except ValueError:
            st.session_state.messages.append({"role": "assistant", "content": f"β HTTP error occurred: {response.status_code} - {response.text}"})
    except Exception as err:
        st.session_state.messages.append({"role": "assistant", "content": f"β An unexpected error occurred: {err}"})
def handle_user_query():
    """Render the chat UI: threshold slider, query input, chat history,
    and an expander listing the retrieved email chunks."""
    st.header("π¬ Let's Chat with Your Emails")
    # Threshold control lives in a collapsed expander.
    with st.expander("π§ Adjust Similarity Threshold", expanded=False):
        similarity_threshold = st.slider(
            "Select Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.05,
            help="Adjust the similarity threshold to control the relevance of retrieved emails. Higher values yield more relevant results.",
            key='similarity_threshold'
        )
    user_input = st.chat_input("Enter your query:")
    if user_input:
        # Record the user's turn, retrieve context, then (if any) ask the LLM.
        st.session_state.messages.append({"role": "user", "content": user_input})
        process_candidate_emails(user_input, similarity_threshold)
        if st.session_state.candidate_context:
            call_llm_api(user_input)
    # Replay the conversation so far.
    for msg in st.session_state.messages:
        role = msg["role"]
        if role in ("user", "assistant"):
            with st.chat_message(role):
                st.markdown(msg["content"])
    # Show the retrieved emails with bodies truncated to 150-char snippets.
    if st.session_state.raw_candidates:
        with st.expander("π Matching Email Chunks:", expanded=False):
            for candidate, sim in st.session_state.raw_candidates:
                body = candidate.get('body', 'No Content')
                snippet = body if len(body) <= 150 else body[:150] + "..."
                st.markdown(
                    f"**From:** {candidate.get('sender','Unknown')} <br>"
                    f"**To:** {candidate.get('to','Unknown')} <br>"
                    f"**Date:** {candidate.get('date','Unknown')} <br>"
                    f"**Subject:** {candidate.get('subject','No Subject')} <br>"
                    f"**Body Snippet:** {snippet} <br>"
                    f"**Similarity:** {sim:.4f}",
                    unsafe_allow_html=True
                )
# ===============================
# 6. Main Application Logic
# ===============================
def main():
    """Wire together authentication, data management, and the chat UI."""
    st.sidebar.header("π Gmail Authentication")
    credentials_file = st.sidebar.file_uploader("π Upload credentials.json", type=["json"])
    data_management_option = st.sidebar.selectbox(
        "Choose an option",
        ["Upload Pre-existing Data", "Authenticate and Create New Data"],
        index=1  # Default to "Authenticate and Create New Data"
    )
    if data_management_option == "Upload Pre-existing Data":
        uploaded_db = st.sidebar.file_uploader("π Upload vector database (vector_database.pkl)", type=["pkl"])
        if uploaded_db:
            # Warn (but proceed) when the upload exceeds 200MB.
            file_size_mb = uploaded_db.size / (1024 * 1024)
            if file_size_mb > 200:
                st.warning("β οΈ The uploaded file is larger than 200MB. It may take longer to load, but processing will continue.")
            try:
                # NOTE(review): unpickling an uploaded file executes arbitrary
                # code — only load databases you created yourself.
                vector_db = pickle.load(uploaded_db)
                st.session_state.vector_store = vector_db.get("vector_store")
                st.session_state.embeddings = vector_db.get("embeddings")
                st.session_state.data_chunks = vector_db.get("data_chunks")
                if not st.session_state.vector_db_message_shown:
                    st.success("π Vector database loaded successfully from upload!")
                    st.session_state.vector_db_message_shown = True
            except Exception as e:
                st.error(f"β Error loading vector database: {e}")
    elif data_management_option == "Authenticate and Create New Data":
        if credentials_file and st.sidebar.button("π Authenticate"):
            reset_session_state()
            # Persist the uploaded client secrets so the OAuth flow can read them.
            with open("credentials.json", "wb") as f:
                f.write(credentials_file.getbuffer())
            authenticate_gmail("credentials.json")
        if st.session_state.auth_url:
            st.sidebar.markdown("### π **Authorization URL:**")
            st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
            st.sidebar.text_input("π Enter the authorization code:", key="auth_code")
            # BUG FIX: this button label (and the success label below) were
            # split across two source lines, producing unterminated strings.
            if st.sidebar.button("✅ Submit Authentication Code"):
                submit_auth_code()
    if data_management_option == "Authenticate and Create New Data" and st.session_state.authenticated:
        st.sidebar.success("✅ You are authenticated!")
        st.header("π Data Management")
        # Multi-select widget for folders (labels).
        folders = st.multiselect("Select Labels (Folders) to Process Emails From:",
                                 ["INBOX", "SENT", "DRAFTS", "TRASH", "SPAM"], default=["INBOX"])
        if st.button("π₯ Create Chunks and Embed Data"):
            service = build('gmail', 'v1', credentials=st.session_state.creds)
            all_chunks = []
            # Fetch each selected folder separately, then combine the results.
            for folder in folders:
                st.session_state.data_chunks = []
                create_chunks_from_gmail(service, folder)
                if st.session_state.data_chunks:
                    all_chunks.extend(st.session_state.data_chunks)
            st.session_state.data_chunks = all_chunks
            if st.session_state.data_chunks:
                embed_emails(st.session_state.data_chunks)
        if st.session_state.vector_store is not None:
            with st.expander("πΎ Download Data", expanded=False):
                save_vector_database()
    if st.session_state.vector_store is not None:
        handle_user_query()
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()
|