Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,189 +1,189 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import os
|
3 |
-
import uuid
|
4 |
-
import shutil
|
5 |
-
from datetime import datetime, timedelta
|
6 |
-
from dotenv import load_dotenv
|
7 |
-
from chatMode import chat_response
|
8 |
-
from modules.pdfExtractor import PdfConverter
|
9 |
-
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
|
10 |
-
from sentence_transformers import SentenceTransformer
|
11 |
-
from modules.llm import GroqClient, GroqCompletion
|
12 |
-
import chromadb
|
13 |
-
import json
|
14 |
-
|
15 |
-
# Load environment variables
|
16 |
-
load_dotenv()
|
17 |
-
|
18 |
-
######## Embedding Model ########
|
19 |
-
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
|
20 |
-
embeddModel.max_seq_length = 512
|
21 |
-
chunk_size, chunk_overlap, top_k_default =
|
22 |
-
|
23 |
-
######## Groq to LLM Connect ########
|
24 |
-
api_key = os.getenv("GROQ_API_KEY")
|
25 |
-
groq_client = GroqClient(api_key)
|
26 |
-
llm_model = {
|
27 |
-
"Gemma9B": "gemma2-9b-it",
|
28 |
-
"Gemma7B": "gemma-7b-it",
|
29 |
-
"LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
|
30 |
-
"LLama3.1-70B": "llama-3.1-70b-versatile",
|
31 |
-
"LLama3-70B": "llama3-70b-8192",
|
32 |
-
"LLama3.2-90B": "llama-3.2-90b-text-preview",
|
33 |
-
"Mixtral8x7B": "mixtral-8x7b-32768"
|
34 |
-
}
|
35 |
-
max_tokens = {
|
36 |
-
"Gemma9B": 8192,
|
37 |
-
"Gemma7B": 8192,
|
38 |
-
"LLama3-70B": 8192,
|
39 |
-
"LLama3.1-70B": 8000,
|
40 |
-
"LLama3-70B": 8192,
|
41 |
-
"LLama3.2-90B": 8192,
|
42 |
-
"Mixtral8x7B": 32768
|
43 |
-
}
|
44 |
-
|
45 |
-
## Time-based cleanup settings
|
46 |
-
EXPIRATION_TIME = timedelta(hours=6)
|
47 |
-
UPLOAD_DIR = "Uploaded"
|
48 |
-
VECTOR_DB_DIR = "vectorDB"
|
49 |
-
LOG_FILE = "upload_log.json"
|
50 |
-
|
51 |
-
## Initialize Streamlit app
|
52 |
-
st.set_page_config(page_title="ChatPDF", layout="wide")
|
53 |
-
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)
|
54 |
-
|
55 |
-
## Function to log upload time
|
56 |
-
def log_upload_time(unique_id):
|
57 |
-
upload_time = datetime.now().isoformat()
|
58 |
-
log_entry = {unique_id: upload_time}
|
59 |
-
if os.path.exists(LOG_FILE):
|
60 |
-
with open(LOG_FILE, "r") as f:
|
61 |
-
log_data = json.load(f)
|
62 |
-
log_data.update(log_entry)
|
63 |
-
else:
|
64 |
-
log_data = log_entry
|
65 |
-
|
66 |
-
with open(LOG_FILE, "w") as f:
|
67 |
-
json.dump(log_data, f)
|
68 |
-
|
69 |
-
## Cleanup expired files based on log
|
70 |
-
def cleanup_expired_files():
|
71 |
-
current_time = datetime.now()
|
72 |
-
|
73 |
-
# Load upload log
|
74 |
-
if os.path.exists(LOG_FILE):
|
75 |
-
with open(LOG_FILE, "r") as f:
|
76 |
-
log_data = json.load(f)
|
77 |
-
|
78 |
-
keys_to_delete = [] # List to keep track of keys to delete
|
79 |
-
# Check each entry in the log
|
80 |
-
for unique_id, upload_time in log_data.items():
|
81 |
-
upload_time_dt = datetime.fromisoformat(upload_time)
|
82 |
-
if current_time - upload_time_dt > EXPIRATION_TIME:
|
83 |
-
# Add key to the list for deletion
|
84 |
-
keys_to_delete.append(unique_id)
|
85 |
-
|
86 |
-
# Remove files if expired
|
87 |
-
pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
|
88 |
-
vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
|
89 |
-
|
90 |
-
if os.path.isfile(pdf_file_path):
|
91 |
-
os.remove(pdf_file_path)
|
92 |
-
if os.path.isdir(vector_db_path):
|
93 |
-
shutil.rmtree(vector_db_path)
|
94 |
-
|
95 |
-
# Now delete the keys from log_data after iteration
|
96 |
-
for key in keys_to_delete:
|
97 |
-
del log_data[key]
|
98 |
-
|
99 |
-
# Save updated log
|
100 |
-
with open(LOG_FILE, "w") as f:
|
101 |
-
json.dump(log_data, f)
|
102 |
-
|
103 |
-
## Context Taking, PDF Upload, and Mode Selection
|
104 |
-
with st.sidebar:
|
105 |
-
st.title("Upload PDF:")
|
106 |
-
|
107 |
-
research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields with commas")
|
108 |
-
option = ''
|
109 |
-
|
110 |
-
if not research_field:
|
111 |
-
st.info("Please enter a research field to proceed.")
|
112 |
-
option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
|
113 |
-
uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
|
114 |
-
else:
|
115 |
-
option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
|
116 |
-
uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)
|
117 |
-
|
118 |
-
temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
|
119 |
-
selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
|
120 |
-
top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)
|
121 |
-
|
122 |
-
## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
|
123 |
-
if 'db_client' not in st.session_state:
|
124 |
-
unique_id = str(uuid.uuid4())
|
125 |
-
st.session_state['unique_id'] = unique_id
|
126 |
-
db_path = os.path.join(VECTOR_DB_DIR, unique_id)
|
127 |
-
os.makedirs(db_path, exist_ok=True)
|
128 |
-
st.session_state['db_path'] = db_path
|
129 |
-
st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)
|
130 |
-
|
131 |
-
# Log the upload time
|
132 |
-
log_upload_time(unique_id)
|
133 |
-
|
134 |
-
# Access session-stored variables
|
135 |
-
db_client = st.session_state['db_client']
|
136 |
-
unique_id = st.session_state['unique_id']
|
137 |
-
db_path = st.session_state['db_path']
|
138 |
-
|
139 |
-
if 'document_text' not in st.session_state:
|
140 |
-
st.session_state['document_text'] = None
|
141 |
-
|
142 |
-
if 'text_embeddings' not in st.session_state:
|
143 |
-
st.session_state['text_embeddings'] = None
|
144 |
-
|
145 |
-
## Handle PDF Upload and Processing
|
146 |
-
if uploaded_file is not None and st.session_state['document_text'] is None:
|
147 |
-
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
148 |
-
file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
|
149 |
-
with open(file_path, "wb") as file:
|
150 |
-
file.write(uploaded_file.getvalue())
|
151 |
-
|
152 |
-
document_text = PdfConverter(file_path).convert_to_markdown()
|
153 |
-
st.session_state['document_text'] = document_text
|
154 |
-
|
155 |
-
text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
|
156 |
-
text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
|
157 |
-
st.session_state['text_embeddings'] = text_contents_embeddings
|
158 |
-
|
159 |
-
if st.session_state['document_text'] and st.session_state['text_embeddings']:
|
160 |
-
document_text = st.session_state['document_text']
|
161 |
-
text_contents_embeddings = st.session_state['text_embeddings']
|
162 |
-
else:
|
163 |
-
st.stop()
|
164 |
-
|
165 |
-
q_input = st.chat_input(key="input", placeholder="Ask your question")
|
166 |
-
|
167 |
-
if q_input:
|
168 |
-
if option == "Chat":
|
169 |
-
query_embedding = ragQuery(embeddModel, q_input)
|
170 |
-
top_k_matches = similarityChroma(query_embedding, db_client, top_k)
|
171 |
-
|
172 |
-
LLMmodel = llm_model[selected_llm_model]
|
173 |
-
domain = research_field
|
174 |
-
prompt_template = q_input
|
175 |
-
user_content = top_k_matches
|
176 |
-
max_tokens = max_tokens[selected_llm_model]
|
177 |
-
print(max_tokens)
|
178 |
-
top_p = 1
|
179 |
-
stream = True
|
180 |
-
stop = None
|
181 |
-
|
182 |
-
groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, max_tokens, top_p, stream, stop)
|
183 |
-
result = groq_completion.create_completion()
|
184 |
-
|
185 |
-
with st.spinner("Processing..."):
|
186 |
-
chat_response(q_input, result)
|
187 |
-
|
188 |
-
## Call the cleanup function periodically
|
189 |
-
cleanup_expired_files()
|
|
|
1 |
+
import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file.
load_dotenv()

######## Embedding Model ########
# Local sentence-transformers checkpoint bundled with the app in ./embeddingModel.
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512  # cap input length for the encoder
# Chunking parameters used when splitting the PDF text before embedding.
chunk_size, chunk_overlap, top_k_default = 1000, 300, 5

######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
# UI label -> Groq model identifier.
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
# UI label -> max completion tokens for that model.
# NOTE: the original dict listed "LLama3-70B" twice (the duplicate key is
# removed here) and was missing "LLama3-70B-Preview", which made selecting
# that model raise KeyError on max_tokens[selected_llm_model].
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,
    "LLama3-70B": 8192,
    "LLama3.1-70B": 8000,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}

## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)  # uploads older than this are purged
UPLOAD_DIR = "Uploaded"               # where raw uploaded PDFs are stored
VECTOR_DB_DIR = "vectorDB"            # per-session Chroma persistence dirs
LOG_FILE = "upload_log.json"          # JSON map: unique_id -> ISO upload timestamp

## Initialize Streamlit app
st.set_page_config(page_title="ChatPDF", layout="wide")
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)
|
54 |
+
|
55 |
+
## Function to log upload time
def log_upload_time(unique_id):
    """Record the upload timestamp for *unique_id* in the JSON log file.

    The log file maps unique_id -> ISO-8601 timestamp. Existing entries
    are preserved; a fresh log is created when the file does not exist.
    """
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
    else:
        log_data = {}

    log_data[unique_id] = datetime.now().isoformat()

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
|
68 |
+
|
69 |
+
## Cleanup expired files based on log
def cleanup_expired_files():
    """Delete uploads and vector stores older than EXPIRATION_TIME.

    Reads the upload log, removes the uploaded PDF and the per-session
    Chroma directory for every expired entry, drops those entries from
    the log, and rewrites it. Does nothing when no log file exists yet
    (the original fell through and hit a NameError on log_data in that
    case).
    """
    if not os.path.exists(LOG_FILE):
        return

    # Load upload log
    with open(LOG_FILE, "r") as f:
        log_data = json.load(f)

    current_time = datetime.now()
    keys_to_delete = []  # collect keys; never mutate the dict while iterating

    # Check each entry in the log
    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            keys_to_delete.append(unique_id)

            # Remove files if expired
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)

            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)

    # Now delete the keys from log_data after iteration
    for key in keys_to_delete:
        del log_data[key]

    # Save updated log
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
|
102 |
+
|
103 |
+
## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")

    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields with commas")
    option = ''

    # Mode selector and uploader stay disabled until a research field is entered.
    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)

    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=5)

## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)

    # Log the upload time so cleanup_expired_files() can expire this session later.
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']

if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None

if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None

## Handle PDF Upload and Processing (runs once per session: skipped when text is cached)
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())

    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text

    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings

if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    # No processed PDF yet: stop rendering the rest of the page.
    st.stop()

q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat":
        # Embed the question and retrieve the top-k matching chunks.
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        # Distinct name so the module-level max_tokens dict is not shadowed
        # (the original rebound max_tokens to an int and print()-debugged it).
        model_max_tokens = max_tokens[selected_llm_model]
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, model_max_tokens, top_p, stream, stop)
        result = groq_completion.create_completion()

        with st.spinner("Processing..."):
            chat_response(q_input, result)

## Call the cleanup function periodically
cleanup_expired_files()
|