import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json

# Load environment variables
load_dotenv()

######## Embedding Model ########
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512
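# Chunking parameters: chunk size, overlap between adjacent chunks, and default number of chunks to retrieve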
chunk_size, chunk_overlap, top_k_default = 1000, 300, 5

######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
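# Sidebar display names mapped to Groq model identifiers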
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it", 
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",       
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview", 
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
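# Maximum token budget passed to the completion request for each model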
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192, 
    "LLama3-70B": 8192,       
    "LLama3.1-70B": 8000,
    "LLama3-70B": 8192,
    "LLama3.2-90B": 8192, 
    "Mixtral8x7B": 32768
}

## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"
VECTOR_DB_DIR = "vectorDB"
LOG_FILE = "upload_log.json"

## Initialize Streamlit app
st.set_page_config(page_title="Ospyn AI", layout="wide")
st.markdown("<h2 style='text-align: center;'>Ospyn AI</h2>", unsafe_allow_html=True)

## Function to log upload time
def log_upload_time(unique_id):
    upload_time = datetime.now().isoformat()
    log_entry = {unique_id: upload_time}
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
        log_data.update(log_entry)
    else:
        log_data = log_entry

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)

## Cleanup expired files based on log
def cleanup_expired_files():
    current_time = datetime.now()
    
    # Load upload log
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
    
        keys_to_delete = []  # List to keep track of keys to delete
        # Check each entry in the log
        for unique_id, upload_time in log_data.items():
            upload_time_dt = datetime.fromisoformat(upload_time)
            if current_time - upload_time_dt > EXPIRATION_TIME:
                # Add key to the list for deletion
                keys_to_delete.append(unique_id)
                
                # Remove files if expired
                pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
                vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
                
                if os.path.isfile(pdf_file_path):
                    os.remove(pdf_file_path)
                if os.path.isdir(vector_db_path):
                    shutil.rmtree(vector_db_path)
        
        # Now delete the keys from log_data after iteration
        for key in keys_to_delete:
            del log_data[key]
        
        # Save updated log
        with open(LOG_FILE, "w") as f:
            json.dump(log_data, f)

## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Select Mode:")
    option = st.selectbox(
        'Choose your interaction mode',
        ('Chat PDF', 'Chat LLM', 'Graph and Table', 'Code', 'Custom Prompting')
    )
    
    if option == "Chat PDF":
        st.title("Upload PDF:")
        research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields separated by commas")
        if not research_field:
            st.info("Please enter a research field to proceed.")
            uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
        else:
            uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)
    else:
        research_field = None
        uploaded_file = None

    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)

## Initialize unique ID, db_client, db_path, and timestamp if needed
if 'db_client' not in st.session_state and option == "Chat PDF":
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
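    # Persistent Chroma client scoped to this session's vector store directory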
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)

    # Log the upload time
    log_upload_time(unique_id)

# Access session-stored variables
if option == "Chat PDF":
    db_client = st.session_state['db_client']
    unique_id = st.session_state['unique_id']
    db_path = st.session_state['db_path']

    if 'document_text' not in st.session_state:
        st.session_state['document_text'] = None

    if 'text_embeddings' not in st.session_state:
        st.session_state['text_embeddings'] = None

## Handle PDF Upload and Processing
if option == "Chat PDF" and uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())

    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text
    
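    # Chunk the extracted text and store its embeddings in this session's Chroma collection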
    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings

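# Proceed only once the PDF text and embeddings are ready; otherwise halt this rerun until a PDF has been processed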
if option == "Chat PDF" and st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    if option == "Chat PDF":
        st.stop()

## Chat Input for Both Modes
q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat PDF":
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        token_limit = max_tokens[selected_llm_model]

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, token_limit, top_p=1, stream=True, stop=None)
        result = groq_completion.create_completion()
        
        with st.spinner("Processing..."):
            chat_response(q_input, result)

    elif option == "Chat LLM":
        LLMmodel = llm_model[selected_llm_model]
        domain = "General"
        prompt_template = q_input
        user_content = ""
        token_limit = max_tokens[selected_llm_model]

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, token_limit, top_p=1, stream=True, stop=None)
        result = groq_completion.create_completion()
        
        with st.spinner("Processing..."):
            chat_response(q_input, result)

## Periodic Cleanup
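# Runs on every Streamlit rerun; removes uploaded PDFs and vector stores older than EXPIRATION_TIME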
cleanup_expired_files()