import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json

# Load environment variables (python-dotenv) before GROQ_API_KEY is read below.
load_dotenv()

######## Embedding Model ########
# The sentence-transformer is loaded from the "embeddingModel" folder located
# in the current working directory.
embedding_model_dir = os.path.join(os.getcwd(), "embeddingModel")
embeddModel = SentenceTransformer(embedding_model_dir)
embeddModel.max_seq_length = 512

# chunk_size / chunk_overlap are handed to contextChunks when a PDF is processed.
chunk_size = 2000
chunk_overlap = 200
top_k_default = 5

######## Groq to LLM Connect ########
# The key is None when GROQ_API_KEY is not set in the environment.
api_key = os.environ.get("GROQ_API_KEY")
groq_client = GroqClient(api_key)
# Sidebar display name -> Groq model identifier.
# NOTE: the sidebar preselects an entry by position (index=3), so keep the
# ordering stable when adding or removing models.
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768",
}

# Sidebar display name -> max_tokens value used for the completion request.
# The keys must mirror `llm_model` exactly: the value is looked up with the
# same display name, so a missing key raises KeyError when that model is
# selected. (The third entry used to be a duplicate "LLama3-70B" key, which
# left "LLama3-70B-Preview" without a limit.)
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768,
}

## Time-based cleanup settings
# Maximum age of a logged session before its files are removed by
# cleanup_expired_files.
EXPIRATION_TIME: timedelta = timedelta(hours=6)

# Directory that holds the saved "<unique_id>_paper.pdf" uploads.
UPLOAD_DIR: str = "Uploaded"

# Parent directory of the per-session Chroma stores (one sub-folder per id).
VECTOR_DB_DIR: str = "vectorDB"

# JSON file mapping each session id to the ISO timestamp it was logged at.
LOG_FILE: str = "upload_log.json"

## Initialize Streamlit app
# Page configuration is issued before any other Streamlit element is drawn.
st.set_page_config(
    page_title="ChatPDF",
    layout="wide",
)

# Centered page heading; raw HTML is needed for the inline text-align style.
page_heading_html = "<h2 style='text-align: center;'>chatPDF</h2>"
st.markdown(page_heading_html, unsafe_allow_html=True)

## Function to log upload time
def log_upload_time(unique_id):
    """Record the current local time for ``unique_id`` in the JSON upload log.

    The log (``LOG_FILE``) is a single JSON object mapping each id to an
    ISO-8601 timestamp string. An existing entry for the same id is
    overwritten. The file is created if it does not exist yet.

    Args:
        unique_id: Key under which the timestamp is stored.
    """
    log_data = {}
    if os.path.exists(LOG_FILE):
        try:
            with open(LOG_FILE, "r") as f:
                loaded = json.load(f)
        except json.JSONDecodeError:
            # An empty or truncated log (e.g. an interrupted write) must not
            # prevent a new session from starting; begin a fresh log instead.
            loaded = None
        # Anything other than a JSON object cannot hold id -> timestamp pairs.
        if isinstance(loaded, dict):
            log_data = loaded

    log_data[unique_id] = datetime.now().isoformat()

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)

## Cleanup expired files based on log
def cleanup_expired_files():
    """Delete uploads and vector stores whose log entry has expired.

    Reads ``LOG_FILE`` (a JSON object of id -> ISO timestamp). For every entry
    older than ``EXPIRATION_TIME`` it removes
    ``UPLOAD_DIR/<id>_paper.pdf`` and the directory ``VECTOR_DB_DIR/<id>`` if
    they exist, then drops the entry from the log.

    The function is a no-op when the log is missing, is not valid JSON, or
    does not contain a JSON object. Entries with an unparseable timestamp are
    left untouched, and entries whose files could not be removed stay in the
    log so that a later call retries them.
    """
    if not os.path.exists(LOG_FILE):
        return

    try:
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
    except json.JSONDecodeError:
        # Unreadable log: nothing can be matched to files, so skip this sweep
        # rather than take the whole app down.
        return
    if not isinstance(log_data, dict):
        return

    current_time = datetime.now()
    expired_ids = []  # collected first; the dict is not mutated while iterating

    for unique_id, upload_time in log_data.items():
        try:
            upload_time_dt = datetime.fromisoformat(upload_time)
        except (TypeError, ValueError):
            # Malformed timestamp: one bad entry must not abort the sweep.
            continue

        if current_time - upload_time_dt <= EXPIRATION_TIME:
            continue

        pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
        vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)

        try:
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)
        except OSError:
            # The path may be locked or may have been removed concurrently
            # between the existence check and the delete. Keep the log entry
            # so the next sweep tries again.
            continue

        expired_ids.append(unique_id)

    if not expired_ids:
        # Nothing changed; avoid rewriting the log needlessly.
        return

    for key in expired_ids:
        del log_data[key]

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)

## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")

    research_field = st.text_input(
        "Research Field: ",
        key="research_field",
        placeholder="Enter research fields with commas",
    )

    # Mode selection and the uploader are rendered either way, but stay
    # disabled until a research field has been entered.
    inputs_locked = not research_field
    if inputs_locked:
        st.info("Please enter a research field to proceed.")

    option = st.selectbox(
        'Select Mode',
        ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'),
        disabled=inputs_locked,
    )
    uploaded_file = st.file_uploader("", type=["pdf"], disabled=inputs_locked)

    # Generation / retrieval settings, always available.
    temperature = st.slider(
        "Select Temperature",
        min_value=0.0,
        max_value=1.0,
        value=0.05,
        step=0.01,
    )
    # index=3 preselects the fourth entry of llm_model.
    selected_llm_model = st.selectbox(
        "Select LLM Model",
        options=list(llm_model),
        index=3,
    )
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)

## First run of a session: create its id, its Chroma directory/client, and
## log the creation time. Later reruns reuse what is stored in session state.
if 'db_client' not in st.session_state:
    new_session_id = str(uuid.uuid4())
    session_db_path = os.path.join(VECTOR_DB_DIR, new_session_id)

    st.session_state['unique_id'] = new_session_id
    os.makedirs(session_db_path, exist_ok=True)
    st.session_state['db_path'] = session_db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=session_db_path)

    # Timestamp used later by cleanup_expired_files to expire this session's files.
    log_upload_time(new_session_id)

# Bind the session-stored values to the module-level names used further down.
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
db_client = st.session_state['db_client']

# Placeholders that stay None until a PDF has been processed.
for state_key in ('document_text', 'text_embeddings'):
    if state_key not in st.session_state:
        st.session_state[state_key] = None

## Handle PDF Upload and Processing
# Runs once per session: only while no document text has been stored yet.
if uploaded_file is not None and st.session_state['document_text'] is None:
    # Persist the upload under the session id so cleanup can find it later.
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())

    # Convert the PDF to markdown text and cache it for subsequent reruns.
    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text

    # Chunk the text, embed the chunks into this session's Chroma store and
    # cache the result.
    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings

if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    # st.stop() halts the script right here, so the cleanup call at the end of
    # the file is never reached on runs without a processed document. Sweep
    # expired files before stopping so they do not pile up in that case.
    cleanup_expired_files()
    st.stop()

q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    # Only the 'Chat' mode is handled here; other selected modes fall through.
    if option == "Chat":
        # Embed the question and fetch the top_k most similar stored chunks.
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        # Use a separate name for the per-request limit: assigning it back to
        # `max_tokens` would replace the module-level lookup table with an int
        # for the remainder of the run.
        completion_max_tokens = max_tokens[selected_llm_model]
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(
            groq_client,
            LLMmodel,
            domain,
            prompt_template,
            user_content,
            temperature,
            completion_max_tokens,
            top_p,
            stream,
            stop,
        )
        result = groq_completion.create_completion()

        with st.spinner("Processing..."):
            chat_response(q_input, result)

## Sweep expired uploads/vector stores on every script run that reaches this
## point (runs halted earlier by st.stop() do not get here).
cleanup_expired_files()