import streamlit as st
import os
import uuid
import shutil
import json
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb

# Load environment variables
load_dotenv()

######## Embedding Model ########
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512
chunk_size, chunk_overlap, top_k_default = 2000, 200, 5

######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
# Per-model completion-token limits, one entry per key in llm_model
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}

## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"
VECTOR_DB_DIR = "vectorDB"
LOG_FILE = "upload_log.json"

## Initialize Streamlit app
st.set_page_config(page_title="ChatPDF", layout="wide")
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)

## Record a session's upload time so its files can be expired later
def log_upload_time(unique_id):
    upload_time = datetime.now().isoformat()
    log_entry = {unique_id: upload_time}
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
        log_data.update(log_entry)
    else:
        log_data = log_entry
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)

## Delete uploaded PDFs and vector stores older than EXPIRATION_TIME
def cleanup_expired_files():
    current_time = datetime.now()
    if not os.path.exists(LOG_FILE):
        return
    with open(LOG_FILE, "r") as f:
        log_data = json.load(f)

    keys_to_delete = []  # Collect expired IDs; don't mutate the dict while iterating
    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            keys_to_delete.append(unique_id)
            # Remove the uploaded PDF and its vector store, if present
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)

    # Drop expired entries after iteration, then persist the updated log
    for key in keys_to_delete:
        del log_data[key]
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
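## For reference, the upload log written by log_upload_time is a flat JSON
## object mapping each session's unique_id to its upload timestamp in ISO 8601
## form, e.g. (illustrative values only):
##   {"1d4c2f6e-8a3b-4c71-9e2d-5f0a6b7c8d9e": "2025-01-01T09:30:00.000000"}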
## Research field input, PDF upload, and mode selection
with st.sidebar:
    st.title("Upload PDF:")
    research_field = st.text_input("Research Field: ", key="research_field",
                                   placeholder="Enter research fields with commas")
    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)
    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=top_k_default)

## Initialize unique ID, db client, db path, and upload timestamp once per session
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)
    # Log the upload time
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']

if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None
if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None

## Handle PDF upload and processing (runs once per uploaded document)
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())
    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text
    # Chunk the document and store embeddings in the session's Chroma collection
    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings

if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    st.stop()  # Nothing to chat about until a PDF has been processed

q_input = st.chat_input(key="input", placeholder="Ask your question")
if q_input:
    if option == "Chat":
        # Embed the question and retrieve the top_k most similar chunks
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        model_max_tokens = max_tokens[selected_llm_model]  # Look up the limit without shadowing the dict
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template,
                                         user_content, temperature, model_max_tokens,
                                         top_p, stream, stop)
        with st.spinner("Processing..."):
            result = groq_completion.create_completion()
            chat_response(q_input, result)

## Cleanup runs on every rerun, purging files older than EXPIRATION_TIME
cleanup_expired_files()
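## ------------------------------------------------------------------
## For reference only: a minimal sketch of what the modules/rag.py
## helpers imported above might look like. This is an assumption for
## illustration, not the repo's actual implementation; the _sketch_*
## names and the "pdf_chunks" collection name are hypothetical, and the
## real module may differ in chunking strategy and return types.
## ------------------------------------------------------------------
def _sketch_contextChunks(text, chunk_size, chunk_overlap):
    """Split text into overlapping character windows."""
    step = chunk_size - chunk_overlap
    return [text[i:i + chunk_size] for i in range(0, len(text), step)]

def _sketch_contextEmbeddingChroma(model, chunks, db_client, db_path=None):
    """Embed each chunk and persist it in a Chroma collection."""
    collection = db_client.get_or_create_collection(name="pdf_chunks")
    embeddings = model.encode(chunks).tolist()
    collection.add(ids=[str(i) for i in range(len(chunks))],
                   embeddings=embeddings, documents=chunks)
    return embeddings

def _sketch_ragQuery(model, query):
    """Embed the user's question with the same model used for the chunks."""
    return model.encode([query]).tolist()

def _sketch_similarityChroma(query_embedding, db_client, top_k):
    """Return the top_k most similar stored chunks, joined for the prompt."""
    collection = db_client.get_or_create_collection(name="pdf_chunks")
    hits = collection.query(query_embeddings=query_embedding, n_results=top_k)
    return "\n\n".join(hits["documents"][0])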