import streamlit as st
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import requests
import os
import torch
import pickle
import base64
from googleapiclient.discovery import build
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.transport.requests import Request

# ===============================
# 1. Streamlit App Configuration
# ===============================
st.set_page_config(page_title="📥 Email Chat Application", layout="wide")
st.title("💬 Turn Emails into Conversations: Effortless Chat with Your Inbox! 📩")

# ===============================
# 2. Initialize Session State Variables
# ===============================
if "authenticated" not in st.session_state:
    st.session_state.authenticated = False
if "creds" not in st.session_state:
    st.session_state.creds = None
if "auth_url" not in st.session_state:
    st.session_state.auth_url = None
if "auth_code" not in st.session_state:
    st.session_state.auth_code = ""
if "flow" not in st.session_state:
    st.session_state.flow = None
if "data_chunks" not in st.session_state:
    st.session_state.data_chunks = []  # List to store all email chunks
if "embeddings" not in st.session_state:
    st.session_state.embeddings = None
if "vector_store" not in st.session_state:
    st.session_state.vector_store = None
# For storing candidate context details.
if "candidate_context" not in st.session_state:
    st.session_state.candidate_context = None
if "raw_candidates" not in st.session_state:
    st.session_state.raw_candidates = None
# Initialize chat messages
if "messages" not in st.session_state:
    st.session_state.messages = []
# Flags to ensure success messages are shown only once
if "candidates_message_shown" not in st.session_state:
    st.session_state.candidates_message_shown = False
if "vector_db_message_shown" not in st.session_state:
    st.session_state.vector_db_message_shown = False

def count_tokens(text):
    return len(text.split())
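
# NOTE: count_tokens() is a rough whitespace approximation, not a model
# tokenizer, so it under-counts relative to LLM tokenizers. If exact counts
# mattered, a real tokenizer could be swapped in. A minimal sketch, assuming
# the optional `tiktoken` package is installed:
#
#   import tiktoken
#   enc = tiktoken.get_encoding("cl100k_base")
#   def count_tokens(text):
#       return len(enc.encode(text))
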
# ===============================
# 3. Gmail Authentication Functions (Updated)
# ===============================
def reset_session_state():
    st.session_state.authenticated = False
    st.session_state.creds = None
    st.session_state.auth_url = None
    st.session_state.auth_code = ""
    st.session_state.flow = None
    st.session_state.data_chunks = []
    st.session_state.embeddings = None
    st.session_state.vector_store = None
    st.session_state.candidate_context = None
    st.session_state.raw_candidates = None
    st.session_state.messages = []
    st.session_state.candidates_message_shown = False
    st.session_state.vector_db_message_shown = False
    # Remove any cached artifacts from previous sessions.
    for filename in ["token.json", "data_chunks.pkl", "embeddings.pkl", "vector_store.index", "vector_database.pkl"]:
        if os.path.exists(filename):
            os.remove(filename)

def authenticate_gmail(credentials_file):
    SCOPES = ['https://www.googleapis.com/auth/gmail.readonly']
    creds = None
    if os.path.exists('token.json'):
        try:
            from google.oauth2.credentials import Credentials
            creds = Credentials.from_authorized_user_file('token.json', SCOPES)
            if creds and creds.valid:
                st.session_state.creds = creds
                st.session_state.authenticated = True
                if not st.session_state.candidates_message_shown:
                    st.success("✅ Authentication successful!")
                    st.session_state.candidates_message_shown = True
                return creds
        except Exception as e:
            st.error(f"❌ Invalid token.json file: {e}")
            os.remove('token.json')
    if not creds or not creds.valid:
        if creds and creds.expired and creds.refresh_token:
            creds.refresh(Request())
            st.session_state.creds = creds
            st.session_state.authenticated = True
            if not st.session_state.candidates_message_shown:
                st.success("✅ Authentication successful!")
                st.session_state.candidates_message_shown = True
            with open('token.json', 'w') as token_file:
                token_file.write(creds.to_json())
            return creds
        else:
            if not st.session_state.flow:
                st.session_state.flow = InstalledAppFlow.from_client_secrets_file(credentials_file, SCOPES)
                st.session_state.flow.redirect_uri = 'http://localhost'
                auth_url, _ = st.session_state.flow.authorization_url(prompt='consent')
                st.session_state.auth_url = auth_url
                st.info("🔗 **Authorize the application by visiting the URL below:**")
                st.markdown(f"[Authorize]({st.session_state.auth_url})")

def submit_auth_code():
    try:
        # Attempt to fetch the token using the provided authorization code
        st.session_state.flow.fetch_token(code=st.session_state.auth_code)
        st.session_state.creds = st.session_state.flow.credentials
        # Attempt to write the credentials to token.json
        with open('token.json', 'w') as token_file:
            token_json = st.session_state.creds.to_json()
            token_file.write(token_json)
        # If writing is successful, update the session state
        st.session_state.authenticated = True
        st.success("✅ Authentication successful!")
    except Exception as e:
        # If any error occurs, ensure the authenticated flag is not set
        st.session_state.authenticated = False
        st.error(f"❌ Error during authentication: {e}")

# ===============================
# 4. Email Data Extraction, Embedding and Vector Store Functions
# ===============================
def extract_email_body(payload):
    if 'body' in payload and 'data' in payload['body'] and payload['body']['data']:
        try:
            return base64.urlsafe_b64decode(payload['body']['data'].encode('UTF-8')).decode('UTF-8')
        except Exception as e:
            st.error(f"Error decoding email body: {e}")
            return ""
    if 'parts' in payload:
        for part in payload['parts']:
            if part.get('mimeType') == 'text/plain' and 'data' in part.get('body', {}):
                try:
                    return base64.urlsafe_b64decode(part['body']['data'].encode('UTF-8')).decode('UTF-8')
                except Exception as e:
                    st.error(f"Error decoding email part: {e}")
                    continue
        # Fallback: decode the first part if no usable text/plain part was found.
        if payload['parts']:
            first_part = payload['parts'][0]
            if 'data' in first_part.get('body', {}):
                try:
                    return base64.urlsafe_b64decode(first_part['body']['data'].encode('UTF-8')).decode('UTF-8')
                except Exception as e:
                    st.error(f"Error decoding fallback email part: {e}")
                    return ""
    return ""
def combine_email_text(email):
    # Build the complete email text by joining parts with HTML line breaks.
    parts = []
    if email.get("sender"):
        parts.append("From: " + email["sender"])
    if email.get("to"):
        parts.append("To: " + email["to"])
    if email.get("date"):
        parts.append("Date: " + email["date"])
    if email.get("subject"):
        parts.append("Subject: " + email["subject"])
    if email.get("body"):
        parts.append("Body: " + email["body"])
    return "<br>".join(parts)
def create_chunks_from_gmail(service, label):
    try:
        messages = []
        result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500).execute()
        messages.extend(result.get('messages', []))
        # Follow pagination tokens until every message in the label has been listed.
        while 'nextPageToken' in result:
            token = result["nextPageToken"]
            result = service.users().messages().list(userId='me', labelIds=[label], maxResults=500, pageToken=token).execute()
            messages.extend(result.get('messages', []))
        data_chunks = []
        progress_bar = st.progress(0)
        total = len(messages)
        for idx, msg in enumerate(messages):
            msg_data = service.users().messages().get(userId='me', id=msg['id'], format='full').execute()
            headers = msg_data.get('payload', {}).get('headers', [])
            email_dict = {"id": msg['id']}
            for header in headers:
                name = header.get('name', '').lower()
                if name == 'from':
                    email_dict['sender'] = header.get('value', '')
                elif name == 'subject':
                    email_dict['subject'] = header.get('value', '')
                elif name == 'to':
                    email_dict['to'] = header.get('value', '')
                elif name == 'date':
                    email_dict['date'] = header.get('value', '')
            email_dict['body'] = extract_email_body(msg_data.get('payload', {}))
            data_chunks.append(email_dict)
            progress_bar.progress(min((idx + 1) / total, 1.0))
        st.session_state.data_chunks.extend(data_chunks)
        st.success(f"✅ Fetched {len(data_chunks)} emails from label '{label}'.")
    except Exception as e:
        st.error(f"❌ Error creating chunks from Gmail for label '{label}': {e}")

# -------------------------------
# Cached model loaders for efficiency
# -------------------------------
@st.cache_resource(show_spinner=False)
def get_embed_model():
    model = SentenceTransformer("all-MiniLM-L6-v2")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    return model, device
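
# st.cache_resource keeps a single model instance alive across Streamlit
# reruns, so the SentenceTransformer weights are loaded once per process
# rather than on every widget interaction. Callers simply do:
#
#   embed_model, device = get_embed_model()
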
def embed_emails(email_chunks):
    st.header("📊 Embedding Data and Creating Vector Store")
    progress_bar = st.progress(0)
    with st.spinner('🔄 Embedding data...'):
        try:
            embed_model, device = get_embed_model()
            combined_texts = [combine_email_text(email) for email in email_chunks]
            batch_size = 64
            embeddings = []
            for i in range(0, len(combined_texts), batch_size):
                batch = combined_texts[i:i+batch_size]
                batch_embeddings = embed_model.encode(
                    batch,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    device=device
                )
                embeddings.append(batch_embeddings)
                progress_value = min((i + batch_size) / len(combined_texts), 1.0)
                progress_bar.progress(progress_value)
            embeddings = np.vstack(embeddings)
            # L2-normalize so inner-product search behaves as cosine similarity.
            faiss.normalize_L2(embeddings)
            st.session_state.embeddings = embeddings
            dimension = embeddings.shape[1]
            index = faiss.IndexFlatIP(dimension)
            index.add(embeddings)
            st.session_state.vector_store = index
            st.success("✅ Data embedding and vector store created successfully!")
        except Exception as e:
            st.error(f"❌ Error during embedding: {e}")
# New function to save the entire vector database as a single pickle file.
def save_vector_database():
    try:
        vector_db = {
            "vector_store": st.session_state.vector_store,
            "embeddings": st.session_state.embeddings,
            "data_chunks": st.session_state.data_chunks
        }
        db_data = pickle.dumps(vector_db)
        st.download_button(
            label="💾 Download Vector Database",
            data=db_data,
            file_name="vector_database.pkl",
            mime="application/octet-stream"
        )
    except Exception as e:
        st.error(f"❌ Error saving vector database: {e}")
# ===============================
# 5. Handling User Queries (User-Controlled Threshold)
# ===============================
def preprocess_query(query):
    return query.lower().strip()

def process_candidate_emails(query, similarity_threshold):
    """
    Process the query by computing its embedding, searching the vector store,
    filtering candidates based on a similarity threshold, and building a context string.
    """
    TOP_K = 20  # Retrieve extra candidates to allow for threshold filtering.
    # Reset candidate context for each query
    st.session_state.candidate_context = None
    st.session_state.raw_candidates = None
    if st.session_state.vector_store is None:
        st.error("❌ Please process your email data or load a saved vector database first.")
        return
    try:
        embed_model, device = get_embed_model()
        processed_query = preprocess_query(query)
        query_embedding = embed_model.encode(
            [processed_query],
            convert_to_numpy=True,
            show_progress_bar=False,
            device=device
        )
        faiss.normalize_L2(query_embedding)
        # Perform search
        distances, indices = st.session_state.vector_store.search(query_embedding, TOP_K)
        candidates = []
        for idx, sim in zip(indices[0], distances[0]):
            # FAISS pads results with -1 when fewer than TOP_K vectors exist; skip those slots.
            if idx == -1:
                continue
            # Include candidate only if similarity meets the threshold
            if sim >= similarity_threshold:
                candidates.append((st.session_state.data_chunks[idx], sim))
        if not candidates:
            # Append warning message as assistant message
            st.session_state.messages.append({"role": "assistant", "content": "⚠️ No matching embeddings found for your query with the selected threshold."})
            return
        # Build the context string by concatenating all matching email texts using HTML breaks.
        context_str = ""
        for candidate, sim in candidates:
            context_str += combine_email_text(candidate) + "<br><br>"
        # Optionally limit context size (whitespace tokens, a rough approximation).
        MAX_CONTEXT_TOKENS = 500
        context_tokens = context_str.split()
        if len(context_tokens) > MAX_CONTEXT_TOKENS:
            context_str = " ".join(context_tokens[:MAX_CONTEXT_TOKENS])
        st.session_state.candidate_context = context_str
        st.session_state.raw_candidates = candidates
    except Exception as e:
        st.error(f"❌ An error occurred during processing: {e}")
def call_llm_api(query):
    """
    Send the user's query along with the concatenated matching email texts (context)
    to the LLM API and display the AI response.
    """
    if not st.session_state.candidate_context:
        st.error("❌ No candidate context available. Please try again.")
        return
    # Retrieve the API key from the environment variable 'GroqAPI'
    api_key = os.getenv("GroqAPI")
    if not api_key:
        st.error("❌ API key not found. Please ensure 'GroqAPI' is set in Hugging Face Secrets.")
        return
    payload = {
        "model": "llama-3.3-70b-versatile",  # Adjust model as needed.
        "messages": [
            {"role": "system", "content": f"Use the following context:\n{st.session_state.candidate_context}"},
            {"role": "user", "content": query}
        ]
    }
    url = "https://api.groq.com/openai/v1/chat/completions"  # Groq's OpenAI-compatible endpoint
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()  # Raises stored HTTPError, if one occurred.
        response_json = response.json()
        generated_text = response_json["choices"][0]["message"]["content"]
        # Append AI response to chat messages
        st.session_state.messages.append({"role": "assistant", "content": generated_text})
    except requests.exceptions.HTTPError:
        try:
            error_info = response.json().get("error", {})
            error_message = error_info.get("message", "An unknown error occurred.")
            st.session_state.messages.append({"role": "assistant", "content": f"❌ HTTP error occurred: {error_message}"})
        except ValueError:
            st.session_state.messages.append({"role": "assistant", "content": f"❌ HTTP error occurred: {response.status_code} - {response.text}"})
    except Exception as err:
        st.session_state.messages.append({"role": "assistant", "content": f"❌ An unexpected error occurred: {err}"})
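
# The Groq endpoint is OpenAI-compatible, so the completion lives at
# response_json["choices"][0]["message"]["content"], mirroring the OpenAI
# chat-completions schema. An equivalent minimal call outside Streamlit
# (a sketch, reusing `url` and `headers` as built above):
#
#   r = requests.post(url, headers=headers, timeout=60, json={
#       "model": "llama-3.3-70b-versatile",
#       "messages": [{"role": "user", "content": "ping"}],
#   })
#   print(r.json()["choices"][0]["message"]["content"])
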
def handle_user_query():
    st.header("💬 Let's Chat with Your Emails")
    # Expander for threshold selection
    with st.expander("🔧 Adjust Similarity Threshold", expanded=False):
        similarity_threshold = st.slider(
            "Select Similarity Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.05,
            help="Adjust the similarity threshold to control the relevance of retrieved emails. Higher values yield fewer but more relevant results.",
            key='similarity_threshold'
        )
    # Chat input for user queries
    user_input = st.chat_input("Enter your query:")
    if user_input:
        # Append user message to chat
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Process the query
        process_candidate_emails(user_input, similarity_threshold)
        if st.session_state.candidate_context:
            # Send the query to the LLM API
            call_llm_api(user_input)
    # Display chat messages
    for msg in st.session_state.messages:
        if msg["role"] == "user":
            with st.chat_message("user"):
                st.markdown(msg["content"])
        elif msg["role"] == "assistant":
            with st.chat_message("assistant"):
                st.markdown(msg["content"])
    # Display matching email chunks in an expander
    if st.session_state.raw_candidates:
        with st.expander("🔍 Matching Email Chunks:", expanded=False):
            for candidate, sim in st.session_state.raw_candidates:
                # Show a snippet (first 150 characters) of the body instead of the full content.
                body = candidate.get('body', 'No Content')
                snippet = (body[:150] + "...") if len(body) > 150 else body
                st.markdown(
                    f"**From:** {candidate.get('sender','Unknown')} <br>"
                    f"**To:** {candidate.get('to','Unknown')} <br>"
                    f"**Date:** {candidate.get('date','Unknown')} <br>"
                    f"**Subject:** {candidate.get('subject','No Subject')} <br>"
                    f"**Body Snippet:** {snippet} <br>"
                    f"**Similarity:** {sim:.4f}",
                    unsafe_allow_html=True
                )

# ===============================
# 6. Main Application Logic
# ===============================
def main():
    st.sidebar.header("🔐 Gmail Authentication")
    credentials_file = st.sidebar.file_uploader("📁 Upload credentials.json", type=["json"])
    data_management_option = st.sidebar.selectbox(
        "Choose an option",
        ["Upload Pre-existing Data", "Authenticate and Create New Data"],
        index=1  # Default to "Authenticate and Create New Data"
    )
    if data_management_option == "Upload Pre-existing Data":
        uploaded_db = st.sidebar.file_uploader("📂 Upload vector database (vector_database.pkl)", type=["pkl"])
        if uploaded_db:
            # Check file size; if larger than 200MB, show a warning and then continue.
            file_size_mb = uploaded_db.size / (1024 * 1024)
            if file_size_mb > 200:
                st.warning("⚠️ The uploaded file is larger than 200MB. It may take longer to load, but processing will continue.")
            try:
                vector_db = pickle.load(uploaded_db)
                st.session_state.vector_store = vector_db.get("vector_store")
                st.session_state.embeddings = vector_db.get("embeddings")
                st.session_state.data_chunks = vector_db.get("data_chunks")
                if not st.session_state.vector_db_message_shown:
                    st.success("🎉 Vector database loaded successfully from upload!")
                    st.session_state.vector_db_message_shown = True
            except Exception as e:
                st.error(f"❌ Error loading vector database: {e}")
    elif data_management_option == "Authenticate and Create New Data":
        if credentials_file and st.sidebar.button("🔑 Authenticate"):
            reset_session_state()
            with open("credentials.json", "wb") as f:
                f.write(credentials_file.getbuffer())
            authenticate_gmail("credentials.json")
        if st.session_state.auth_url:
            st.sidebar.markdown("### 🔗 **Authorization URL:**")
            st.sidebar.markdown(f"[Authorize]({st.session_state.auth_url})")
            st.sidebar.text_input("🔑 Enter the authorization code:", key="auth_code")
            if st.sidebar.button("✅ Submit Authentication Code"):
                submit_auth_code()
    if data_management_option == "Authenticate and Create New Data" and st.session_state.authenticated:
        st.sidebar.success("✅ You are authenticated!")
        st.header("📁 Data Management")
        # Multi-select widget for folders (labels); Gmail's system label is "DRAFT", not "DRAFTS".
        folders = st.multiselect("Select Labels (Folders) to Process Emails From:",
                                 ["INBOX", "SENT", "DRAFT", "TRASH", "SPAM"], default=["INBOX"])
        if st.button("📥 Create Chunks and Embed Data"):
            service = build('gmail', 'v1', credentials=st.session_state.creds)
            all_chunks = []
            # Process each selected folder
            for folder in folders:
                # Clear temporary data_chunks so that each folder's data is kept separate
                st.session_state.data_chunks = []
                create_chunks_from_gmail(service, folder)
                if st.session_state.data_chunks:
                    all_chunks.extend(st.session_state.data_chunks)
            st.session_state.data_chunks = all_chunks
            if st.session_state.data_chunks:
                embed_emails(st.session_state.data_chunks)
    # The download and chat sections apply to both freshly created and uploaded databases.
    if st.session_state.vector_store is not None:
        with st.expander("💾 Download Data", expanded=False):
            save_vector_database()
    if st.session_state.vector_store is not None:
        handle_user_query()

if __name__ == "__main__":
    main()