Kurian07 committed on
Commit e59d3be · verified · 1 parent: 60fc5e8

Update app.py

Files changed (1)
  1. app.py +189 -189
app.py CHANGED
@@ -1,189 +1,189 @@
 
import streamlit as st
import os
import uuid
import shutil
from datetime import datetime, timedelta
from dotenv import load_dotenv
from chatMode import chat_response
from modules.pdfExtractor import PdfConverter
from modules.rag import contextChunks, contextEmbeddingChroma, retrieveEmbeddingsChroma, ragQuery, similarityChroma
from sentence_transformers import SentenceTransformer
from modules.llm import GroqClient, GroqCompletion
import chromadb
import json

# Load environment variables
load_dotenv()

######## Embedding Model ########
embeddModel = SentenceTransformer(os.path.join(os.getcwd(), "embeddingModel"))
embeddModel.max_seq_length = 512
-chunk_size, chunk_overlap, top_k_default = 2000, 200, 5
+chunk_size, chunk_overlap, top_k_default = 1000, 300, 5
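
# Sanity check on the new chunking values (assuming contextChunks slides a
# fixed-size character window): stride = chunk_size - chunk_overlap = 700,
# so consecutive chunks share 30% of their text, and a ~70,000-character
# paper yields roughly 100 chunks.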

######## Groq to LLM Connect ########
api_key = os.getenv("GROQ_API_KEY")
groq_client = GroqClient(api_key)
llm_model = {
    "Gemma9B": "gemma2-9b-it",
    "Gemma7B": "gemma-7b-it",
    "LLama3-70B-Preview": "llama3-groq-70b-8192-tool-use-preview",
    "LLama3.1-70B": "llama-3.1-70b-versatile",
    "LLama3-70B": "llama3-70b-8192",
    "LLama3.2-90B": "llama-3.2-90b-text-preview",
    "Mixtral8x7B": "mixtral-8x7b-32768"
}
max_tokens = {
    "Gemma9B": 8192,
    "Gemma7B": 8192,
    "LLama3-70B-Preview": 8192,  # 8192-token window per the model id
    "LLama3-70B": 8192,
    "LLama3.1-70B": 8000,
    "LLama3.2-90B": 8192,
    "Mixtral8x7B": 32768
}
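
# The selected model's cap is looked up below and passed to GroqCompletion
# as its max_tokens argument, presumably bounding the completion length.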

## Time-based cleanup settings
EXPIRATION_TIME = timedelta(hours=6)
UPLOAD_DIR = "Uploaded"
VECTOR_DB_DIR = "vectorDB"
LOG_FILE = "upload_log.json"

## Initialize Streamlit app
st.set_page_config(page_title="ChatPDF", layout="wide")
st.markdown("<h2 style='text-align: center;'>chatPDF</h2>", unsafe_allow_html=True)

## Function to log upload time
def log_upload_time(unique_id):
    upload_time = datetime.now().isoformat()
    log_entry = {unique_id: upload_time}
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, "r") as f:
            log_data = json.load(f)
        log_data.update(log_entry)
    else:
        log_data = log_entry

    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)
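
# upload_log.json ends up as a flat {uuid: ISO-8601 timestamp} map, e.g.
# {"4f3c...": "2024-10-12T09:15:02"} (illustrative values).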

## Cleanup expired files based on log
def cleanup_expired_files():
    current_time = datetime.now()

    # Load upload log; nothing to clean up if it does not exist yet
    if not os.path.exists(LOG_FILE):
        return
    with open(LOG_FILE, "r") as f:
        log_data = json.load(f)

    keys_to_delete = []  # Keys to drop from the log after iteration
    # Check each entry in the log
    for unique_id, upload_time in log_data.items():
        upload_time_dt = datetime.fromisoformat(upload_time)
        if current_time - upload_time_dt > EXPIRATION_TIME:
            keys_to_delete.append(unique_id)

            # Remove the expired upload and its vector store
            pdf_file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
            vector_db_path = os.path.join(VECTOR_DB_DIR, unique_id)
            if os.path.isfile(pdf_file_path):
                os.remove(pdf_file_path)
            if os.path.isdir(vector_db_path):
                shutil.rmtree(vector_db_path)

    # Delete keys only after iterating (mutating a dict while looping
    # over it raises a RuntimeError)
    for key in keys_to_delete:
        del log_data[key]

    # Save updated log
    with open(LOG_FILE, "w") as f:
        json.dump(log_data, f)

## Context Taking, PDF Upload, and Mode Selection
with st.sidebar:
    st.title("Upload PDF:")

    research_field = st.text_input("Research Field: ", key="research_field", placeholder="Enter research fields separated by commas")
    option = ''

    if not research_field:
        st.info("Please enter a research field to proceed.")
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'), disabled=True)
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=True)
    else:
        option = st.selectbox('Select Mode', ('Chat', 'Graph and Table', 'Code', 'Custom Prompting'))
        uploaded_file = st.file_uploader("", type=["pdf"], disabled=False)

    temperature = st.slider("Select Temperature", min_value=0.0, max_value=1.0, value=0.05, step=0.01)
    selected_llm_model = st.selectbox("Select LLM Model", options=list(llm_model.keys()), index=3)
    top_k = st.slider("Select Top K Matches", min_value=1, max_value=20, value=top_k_default)

## Initialize unique ID, db_client, db_path, and timestamp if not already in session state
if 'db_client' not in st.session_state:
    unique_id = str(uuid.uuid4())
    st.session_state['unique_id'] = unique_id
    db_path = os.path.join(VECTOR_DB_DIR, unique_id)
    os.makedirs(db_path, exist_ok=True)
    st.session_state['db_path'] = db_path
    st.session_state['db_client'] = chromadb.PersistentClient(path=db_path)

    # Log the upload time
    log_upload_time(unique_id)

# Access session-stored variables
db_client = st.session_state['db_client']
unique_id = st.session_state['unique_id']
db_path = st.session_state['db_path']
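
# Streamlit reruns this whole script on every user interaction, so the
# Chroma client and IDs are cached in session_state to keep a single
# persistent store per browser session.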

if 'document_text' not in st.session_state:
    st.session_state['document_text'] = None

if 'text_embeddings' not in st.session_state:
    st.session_state['text_embeddings'] = None

## Handle PDF Upload and Processing
if uploaded_file is not None and st.session_state['document_text'] is None:
    os.makedirs(UPLOAD_DIR, exist_ok=True)
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}_paper.pdf")
    with open(file_path, "wb") as file:
        file.write(uploaded_file.getvalue())

    document_text = PdfConverter(file_path).convert_to_markdown()
    st.session_state['document_text'] = document_text

    text_content_chunks = contextChunks(document_text, chunk_size, chunk_overlap)
    text_contents_embeddings = contextEmbeddingChroma(embeddModel, text_content_chunks, db_client, db_path=db_path)
    st.session_state['text_embeddings'] = text_contents_embeddings
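
# The `is None` guard makes parsing and embedding a one-time cost per
# session; subsequent reruns reuse the cached text and embeddings.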

if st.session_state['document_text'] and st.session_state['text_embeddings']:
    document_text = st.session_state['document_text']
    text_contents_embeddings = st.session_state['text_embeddings']
else:
    st.stop()

q_input = st.chat_input(key="input", placeholder="Ask your question")

if q_input:
    if option == "Chat":
        query_embedding = ragQuery(embeddModel, q_input)
        top_k_matches = similarityChroma(query_embedding, db_client, top_k)

        LLMmodel = llm_model[selected_llm_model]
        domain = research_field
        prompt_template = q_input
        user_content = top_k_matches
        llm_max_tokens = max_tokens[selected_llm_model]  # look up the cap without shadowing the max_tokens dict
        top_p = 1
        stream = True
        stop = None

        groq_completion = GroqCompletion(groq_client, LLMmodel, domain, prompt_template, user_content, temperature, llm_max_tokens, top_p, stream, stop)
        result = groq_completion.create_completion()

        with st.spinner("Processing..."):
            chat_response(q_input, result)

## Remove uploads and vector stores older than EXPIRATION_TIME on each rerun
cleanup_expired_files()