Tesneem committed on
Commit 6e9cc55 · verified · 1 Parent(s): 740868b

Create app.py

Files changed (1):
  1. app.py +226 -439

app.py CHANGED
@@ -1,439 +1,226 @@
- #############################################################################################################################
- # Filename : app.py
- # Description: A Streamlit application to showcase how RAG works.
- # Author : Georgios Ioannou
- #
- # Copyright © 2024 by Georgios Ioannou
- #############################################################################################################################
- # Import libraries.
- import os
- import streamlit as st
-
- from dotenv import load_dotenv, find_dotenv
- from huggingface_hub import InferenceClient
- from langchain.prompts import PromptTemplate
- from langchain.schema import Document
- from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
- from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
- from langchain_community.vectorstores import MongoDBAtlasVectorSearch
- from pymongo import MongoClient
- from pymongo.collection import Collection
- from typing import Dict, Any
-
-
- #############################################################################################################################
-
-
- class RAGQuestionAnswering:
-     def __init__(self):
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Initializes the RAG Question Answering system by setting up configuration
-         and loading environment variables.
-
-         Assumptions
-         -----------
-         - Expects .env file with MONGO_URI and HF_TOKEN
-         - Requires proper MongoDB setup with vector search index
-         - Needs connection to Hugging Face API
-
-         Notes
-         -----
-         This is the main class that handles all RAG operations
-         """
-         self.load_environment()
-         self.setup_mongodb()
-         self.setup_embedding_model()
-         self.setup_vector_search()
-         self.setup_rag_chain()
-
-     def load_environment(self) -> None:
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Loads environment variables from .env file and sets up configuration constants.
-
-         Assumptions
-         -----------
-         Expects a .env file with MONGO_URI and HF_TOKEN defined
-
-         Notes
-         -----
-         Will stop the application if required environment variables are missing
-         """
-
-         load_dotenv(find_dotenv())
-         self.MONGO_URI = os.getenv("MONGO_URI")
-         self.HF_TOKEN = os.getenv("HF_TOKEN")
-
-         if not self.MONGO_URI or not self.HF_TOKEN:
-             st.error("Please ensure MONGO_URI and HF_TOKEN are set in your .env file")
-             st.stop()
-
-         # MongoDB configuration.
-         self.DB_NAME = "txts"
-         self.COLLECTION_NAME = "txts_collection"
-         self.VECTOR_SEARCH_INDEX = "vector_index"
-
-     def setup_mongodb(self) -> None:
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Initializes the MongoDB connection and sets up the collection.
-
-         Assumptions
-         -----------
-         - Valid MongoDB URI is available
-         - Database and collection exist in MongoDB Atlas
-
-         Notes
-         -----
-         Uses st.cache_resource for efficient connection management
-         """
-
-         @st.cache_resource
-         def init_mongodb() -> Collection:
-             cluster = MongoClient(self.MONGO_URI)
-             return cluster[self.DB_NAME][self.COLLECTION_NAME]
-
-         self.mongodb_collection = init_mongodb()
-
-     def setup_embedding_model(self) -> None:
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Initializes the embedding model for vector search.
-
-         Assumptions
-         -----------
-         - Valid Hugging Face API token
-         - Internet connection to access the model
-
-         Notes
-         -----
-         Uses the all-mpnet-base-v2 model from sentence-transformers
-         """
-
-         @st.cache_resource
-         def init_embedding_model() -> HuggingFaceInferenceAPIEmbeddings:
-             return HuggingFaceInferenceAPIEmbeddings(
-                 api_key=self.HF_TOKEN,
-                 model_name="sentence-transformers/all-mpnet-base-v2",
-             )
-
-         self.embedding_model = init_embedding_model()
-
-     def setup_vector_search(self) -> None:
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Sets up the vector search functionality using MongoDB Atlas.
-
-         Assumptions
-         -----------
-         - MongoDB Atlas vector search index is properly configured
-         - Valid embedding model is initialized
-
-         Notes
-         -----
-         Creates a retriever with similarity search and score threshold
-         """
-
-         @st.cache_resource
-         def init_vector_search() -> MongoDBAtlasVectorSearch:
-             return MongoDBAtlasVectorSearch.from_connection_string(
-                 connection_string=self.MONGO_URI,
-                 namespace=f"{self.DB_NAME}.{self.COLLECTION_NAME}",
-                 embedding=self.embedding_model,
-                 index_name=self.VECTOR_SEARCH_INDEX,
-             )
-
-         self.vector_search = init_vector_search()
-         self.retriever = self.vector_search.as_retriever(
-             search_type="similarity", search_kwargs={"k": 10, "score_threshold": 0.85}
-         )
-
-     def format_docs(self, docs: list[Document]) -> str:
-         """
-         Parameters
-         ----------
-         **docs:** list[Document] - List of documents to be formatted
-
-         Output
-         ------
-         str: Formatted string containing concatenated document content
-
-         Purpose
-         -------
-         Formats the retrieved documents into a single string for processing
-
-         Assumptions
-         -----------
-         Documents have page_content attribute
-
-         Notes
-         -----
-         Joins documents with double newlines for better readability
-         """
-
-         return "\n\n".join(doc.page_content for doc in docs)
-
-     def generate_response(self, input_dict: Dict[str, Any]) -> str:
-         """
-         Parameters
-         ----------
-         **input_dict:** Dict[str, Any] - Dictionary containing context and question
-
-         Output
-         ------
-         str: Generated response from the model
-
-         Purpose
-         -------
-         Generates a response using the Hugging Face model based on context and question
-
-         Assumptions
-         -----------
-         - Valid Hugging Face API token
-         - Input dictionary contains 'context' and 'question' keys
-
-         Notes
-         -----
-         Uses Qwen2.5-1.5B-Instruct model with controlled temperature
-         """
-         hf_client = InferenceClient(api_key=self.HF_TOKEN)
-         formatted_prompt = self.prompt.format(**input_dict)
-
-         response = hf_client.chat.completions.create(
-             model="Qwen/Qwen2.5-1.5B-Instruct",
-             messages=[
-                 {"role": "system", "content": formatted_prompt},
-                 {"role": "user", "content": input_dict["question"]},
-             ],
-             max_tokens=1000,
-             temperature=0.2,
-         )
-
-         return response.choices[0].message.content
-
-     def setup_rag_chain(self) -> None:
-         """
-         Parameters
-         ----------
-         None
-
-         Output
-         ------
-         None
-
-         Purpose
-         -------
-         Sets up the RAG chain for processing questions and generating answers
-
-         Assumptions
-         -----------
-         Retriever and response generator are properly initialized
-
-         Notes
-         -----
-         Creates a chain that combines retrieval and response generation
-         """
-
-         self.prompt = PromptTemplate.from_template(
-             """Use the following pieces of context to answer the question at the end.
-
-             START OF CONTEXT:
-             {context}
-             END OF CONTEXT:
-
-             START OF QUESTION:
-             {question}
-             END OF QUESTION:
-
-             If you do not know the answer, just say that you do not know.
-             NEVER assume things.
-             """
-         )
-
-         self.rag_chain = {
-             "context": self.retriever | RunnableLambda(self.format_docs),
-             "question": RunnablePassthrough(),
-         } | RunnableLambda(self.generate_response)
-
-     def process_question(self, question: str) -> str:
-         """
-         Parameters
-         ----------
-         **question:** str - The user's question to be answered
-
-         Output
-         ------
-         str: The generated answer to the question
-
-         Purpose
-         -------
-         Processes a user question through the RAG chain and returns an answer
-
-         Assumptions
-         -----------
-         - Question is a non-empty string
-         - RAG chain is properly initialized
-
-         Notes
-         -----
-         Main interface for question-answering functionality
-         """
-
-         return self.rag_chain.invoke(question)
-
-
- #############################################################################################################################
- def setup_streamlit_ui() -> None:
-     """
-     Parameters
-     ----------
-     None
-
-     Output
-     ------
-     None
-
-     Purpose
-     -------
-     Sets up the Streamlit user interface with proper styling and layout
-
-     Assumptions
-     -----------
-     - CSS file exists at ./static/styles/style.css
-     - Image file exists at ./static/images/ctp.png
-
-     Notes
-     -----
-     Handles all UI-related setup and styling
-     """
-
-     st.set_page_config(page_title="RAG Question Answering", page_icon="🤖")
-
-     # Load CSS.
-     with open("./static/styles/style.css") as f:
-         st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
-
-     # Title and subtitles.
-     st.markdown(
-         '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">RAG Question Answering</h1>',
-         unsafe_allow_html=True,
-     )
-     st.markdown(
-         '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">Using Zoom Closed Captioning From The Lectures</h3>',
-         unsafe_allow_html=True,
-     )
-     st.markdown(
-         '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">CUNY Tech Prep Tutorial 5</h2>',
-         unsafe_allow_html=True,
-     )
-
-     # Display logo.
-     left_co, cent_co, last_co = st.columns(3)
-     with cent_co:
-         st.image("./static/images/ctp.png")
-
-
- #############################################################################################################################
-
-
- def main():
-     """
-     Parameters
-     ----------
-     None
-
-     Output
-     ------
-     None
-
-     Purpose
-     -------
-     Main function that runs the Streamlit application
-
-     Assumptions
-     -----------
-     All required environment variables and files are present
-
-     Notes
-     -----
-     Entry point for the application
-     """
-
-     # Setup UI.
-     setup_streamlit_ui()
-
-     # Initialize RAG system.
-     rag_system = RAGQuestionAnswering()
-
-     # Create input elements.
-     query = st.text_input("Question:", key="question_input")
-
-     # Handle submission.
-     if st.button("Submit", type="primary"):
-         if query:
-             with st.spinner("Generating response..."):
-                 response = rag_system.process_question(query)
-                 st.text_area("Answer:", value=response, height=200, disabled=True)
-         else:
-             st.warning("Please enter a question.")
-
-     # Add GitHub link.
-     st.markdown(
-         """
-         <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;">
-         <b>Check out our <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;">GitHub repository</a></b>
-         </p>
-         """,
-         unsafe_allow_html=True,
-     )
-
-
- #############################################################################################################################
- if __name__ == "__main__":
-     main()
 
+ !pip install gradio pymongo langchain-community transformers
+
+ # Import libraries.
+ # Gradio.
+ import gradio as gr
+
+ # File loading and environment variables.
+ import os
+ import sys
+
+ # Gradio theming.
+ from gradio.themes.base import Base
+
+ # HuggingFace LLM.
+ from huggingface_hub import InferenceClient
+
+ # Langchain.
+ from langchain.document_loaders import TextLoader
+ from langchain.prompts import PromptTemplate
+ from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import MongoDBAtlasVectorSearch
+ from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+ # MongoDB.
+ from pymongo import MongoClient
+
+ # Function type hints.
+ from typing import Dict, Any
+
+ # Secrets.
+ from kaggle_secrets import UserSecretsClient
+
+ directory_path = "/kaggle/input/rag-dataset/RAG"
+ sys.path.append(directory_path)
+ print("sys.path =", sys.path)
+
+ my_txts = os.listdir(directory_path)
+ my_txts
+
+ loaders = []
+ for my_txt in my_txts:
+     my_txt_path = os.path.join(directory_path, my_txt)
+     text_loader = TextLoader(my_txt_path)
+     loaders.append(text_loader)
+
+ print("len(loaders) =", len(loaders))
+
+ loaders
+
+ # Load the TXT files.
+
+ data = []
+ for loader in loaders:
+     loaded_text = loader.load()
+     data.append(loaded_text)
+
+ print("len(data) =", len(data), "\n")
+
+ # First TXT file.
+ data[0]
+
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+
+ docs = []
+ for doc in data:
+     chunk = text_splitter.split_documents(doc)
+     docs.append(chunk)
+
+ merged_documents = []
+
+ for doc in docs:
+     merged_documents.extend(doc)
+
+ # Print the merged list of all the documents.
+ print("len(merged_documents) =", len(merged_documents))
+ print(merged_documents)
+
+ # Connect to the MongoDB Atlas cluster using the connection string.
+ user_secrets = UserSecretsClient()
+ MONGO_URI = user_secrets.get_secret("MONGO_URI")
+ cluster = MongoClient(MONGO_URI)
+
+ # Define the MongoDB database and collection name.
+ DB_NAME = "files"
+ COLLECTION_NAME = "files_collection"
+
+ # Connect to the specific collection in the database.
+ MONGODB_COLLECTION = cluster[DB_NAME][COLLECTION_NAME]
+ vector_search_index = "vector_index"
+
+ HF_TOKEN = user_secrets.get_secret("hugging_face")
+ embedding_model = HuggingFaceInferenceAPIEmbeddings(
+     api_key=HF_TOKEN, model_name="sentence-transformers/all-mpnet-base-v2"
+ )
+
+ # Populate MongoDB with the embedded documents (only needed once):
+ # vector_search = MongoDBAtlasVectorSearch.from_documents(
+ #     documents=merged_documents,
+ #     embedding=embedding_model,
+ #     collection=MONGODB_COLLECTION,
+ #     index_name=vector_search_index,
+ # )
+
+ # Connect to the already-populated vector store.
+ vector_search = MongoDBAtlasVectorSearch.from_connection_string(
+     connection_string=MONGO_URI,
+     namespace=f"{DB_NAME}.{COLLECTION_NAME}",
+     embedding=embedding_model,
+     index_name=vector_search_index,
+ )
+
+ query = "why EfficientNetB0?"
+ results = vector_search.similarity_search(query=query, k=25)  # 25 most similar documents.
+
+ print("\n")
+ print(results)
+
+ # k: retrieve only the k most relevant documents.
+ k = 10
+
+ # score_threshold: use only documents with a relevance score above 0.80.
+ score_threshold = 0.80
+
+ # Build the retriever.
+ retriever_1 = vector_search.as_retriever(
+     # Options: similarity, mmr, similarity_score_threshold. Note that score_threshold
+     # only takes effect with search_type="similarity_score_threshold" (see the sketch
+     # after this block).
+     # https://api.python.langchain.com/en/latest/vectorstores/langchain_core.vectorstores.VectorStore.html#langchain_core.vectorstores.VectorStore.as_retriever
+     search_type="similarity",
+     search_kwargs={"k": k, "score_threshold": score_threshold},
+ )
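# Editorial sketch, not part of the commit: a retriever that actually enforces
# the threshold defined above would be built as:
# retriever_1 = vector_search.as_retriever(
#     search_type="similarity_score_threshold",
#     search_kwargs={"k": k, "score_threshold": score_threshold},
# )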
+
+
+ # Initialize the Hugging Face client.
+ hf_client = InferenceClient(api_key=HF_TOKEN)
+
+ # Define the prompt template.
+ prompt = PromptTemplate.from_template(
+     """Use the following pieces of context to answer the question at the end.
+
+ START OF CONTEXT:
+ {context}
+ END OF CONTEXT:
+
+ START OF QUESTION:
+ {question}
+ END OF QUESTION:
+
+ If you do not know the answer, just say that you do not know.
+ NEVER assume things.
+ """
+ )
+
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
+
+
+ def generate_response(input_dict: Dict[str, Any]) -> str:
+     formatted_prompt = prompt.format(**input_dict)
+     # print(formatted_prompt)
+
+     # This is the LLM call.
+     response = hf_client.chat.completions.create(
+         model="Qwen/Qwen2.5-1.5B-Instruct",
+         messages=[
+             {"role": "system", "content": formatted_prompt},
+             {"role": "user", "content": input_dict["question"]},
+         ],
+         max_tokens=1000,
+         temperature=0.2,
+     )
+
+     return response.choices[0].message.content
+
+ rag_chain = (
+     {
+         "context": retriever_1 | RunnableLambda(format_docs),
+         "question": RunnablePassthrough(),
+     }
+     | RunnableLambda(generate_response)
+ )
+
+
+ query = "what is scaling?"
+ answer = rag_chain.invoke(query)
+
+ print("\nQuestion:", query)
+ print("Answer:", answer)
+
+ # Get the source documents related to the query.
+ documents = retriever_1.invoke(query)
+
+ # print("\nSource documents:")
+ # from pprint import pprint
+ # pprint(documents)
+
+
+ query = "How was the GUI implemented?"
+ answer = rag_chain.invoke(query)
+
+ print("\nQuestion:", query)
+ print("Answer:", answer)
+
+ # Get the source documents related to the query and print them.
+ documents = retriever_1.invoke(query)
+ formatted_docs = format_docs(documents)
+ print("\nSource Documents:\n", formatted_docs)