|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import streamlit as st |
|
|
|
from dotenv import load_dotenv, find_dotenv |
|
from huggingface_hub import InferenceClient |
|
from langchain.prompts import PromptTemplate |
|
from langchain.schema import Document |
|
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda |
|
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings |
|
from langchain_community.vectorstores import MongoDBAtlasVectorSearch |
|
from pymongo import MongoClient |
|
from pymongo.collection import Collection |
|
from typing import Dict, Any |
|
|
|
|
|
|
|
|
|
|
|
class RAGQuestionAnswering:
    """
    Parameters
    ----------
    None

    Output
    ------
    None

    Purpose
    -------
    Retrieval-augmented question answering: retrieves document chunks from a
    MongoDB Atlas vector search index and asks a Hugging Face-hosted chat
    model to answer a question using those chunks as context.

    Notes
    -----
    All setup happens in the constructor. The MongoDB client, embedding model
    and vector store are memoised with st.cache_resource, so constructing this
    object on each Streamlit rerun reuses those resources.
    """

    def __init__(self):
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Initializes the RAG Question Answering system by loading configuration
        and building every component the chain needs.

        Assumptions
        -----------
        - Expects a .env file (or environment) with MONGO_URI and HF_TOKEN
        - Requires a MongoDB Atlas collection with a vector search index
        - Needs network access to the Hugging Face Inference API

        Notes
        -----
        Order matters: later setup steps read attributes set by earlier ones
        (e.g. the vector store needs self.embedding_model, the chain needs
        self.retriever).
        """
        self.load_environment()
        self.setup_mongodb()
        self.setup_embedding_model()
        self.setup_vector_search()
        self.setup_rag_chain()

    def load_environment(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Loads environment variables from the nearest .env file and sets the
        configuration attributes MONGO_URI, HF_TOKEN, DB_NAME,
        COLLECTION_NAME and VECTOR_SEARCH_INDEX.

        Assumptions
        -----------
        MONGO_URI and HF_TOKEN are defined in .env or already in the
        process environment.

        Notes
        -----
        If either required variable is missing or empty, an error is shown
        and st.stop() ends the current Streamlit run.
        """
        load_dotenv(find_dotenv())
        self.MONGO_URI = os.getenv("MONGO_URI")
        self.HF_TOKEN = os.getenv("HF_TOKEN")

        if not self.MONGO_URI or not self.HF_TOKEN:
            st.error("Please ensure MONGO_URI and HF_TOKEN are set in your .env file")
            st.stop()

        # Fixed Atlas locations; VECTOR_SEARCH_INDEX must match the name of
        # the vector search index defined on the collection in Atlas.
        self.DB_NAME = "files"
        self.COLLECTION_NAME = "files_collection"
        self.VECTOR_SEARCH_INDEX = "vector_index"

    def setup_mongodb(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Opens a MongoDB client and stores the target collection in
        self.mongodb_collection.

        Assumptions
        -----------
        - self.MONGO_URI, self.DB_NAME and self.COLLECTION_NAME are set
        - The database and collection exist in MongoDB Atlas

        Notes
        -----
        The client is memoised with st.cache_resource and reused across reruns.
        """

        # Config is passed as arguments (not read from `self` in a closure) so
        # it is part of Streamlit's cache key; a closure value is not hashed.
        @st.cache_resource
        def init_mongodb(uri: str, db_name: str, collection_name: str) -> Collection:
            cluster = MongoClient(uri)
            return cluster[db_name][collection_name]

        self.mongodb_collection = init_mongodb(
            self.MONGO_URI, self.DB_NAME, self.COLLECTION_NAME
        )

    def setup_embedding_model(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Creates the embedding model used to embed questions for vector search
        and stores it in self.embedding_model.

        Assumptions
        -----------
        - self.HF_TOKEN is a valid Hugging Face API token
        - Network access to the Hugging Face Inference API

        Notes
        -----
        Uses sentence-transformers/all-mpnet-base-v2 through the Inference API;
        memoised with st.cache_resource.
        """

        # The token is an argument so it participates in the cache key.
        @st.cache_resource
        def init_embedding_model(api_key: str) -> HuggingFaceInferenceAPIEmbeddings:
            return HuggingFaceInferenceAPIEmbeddings(
                api_key=api_key,
                model_name="sentence-transformers/all-mpnet-base-v2",
            )

        self.embedding_model = init_embedding_model(self.HF_TOKEN)

    def setup_vector_search(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Connects the Atlas vector store (self.vector_search) and builds the
        retriever used by the RAG chain (self.retriever).

        Assumptions
        -----------
        - The Atlas vector search index named self.VECTOR_SEARCH_INDEX exists
        - self.embedding_model has been initialized

        Notes
        -----
        The retriever returns the top 10 most similar chunks; no score
        threshold is applied.
        """

        # `_embedding` is underscore-prefixed so Streamlit excludes it from
        # hashing (the embedding object itself need not be hashable).
        @st.cache_resource
        def init_vector_search(
            connection_string: str,
            namespace: str,
            index_name: str,
            _embedding: HuggingFaceInferenceAPIEmbeddings,
        ) -> MongoDBAtlasVectorSearch:
            return MongoDBAtlasVectorSearch.from_connection_string(
                connection_string=connection_string,
                namespace=namespace,
                embedding=_embedding,
                index_name=index_name,
            )

        self.vector_search = init_vector_search(
            self.MONGO_URI,
            f"{self.DB_NAME}.{self.COLLECTION_NAME}",
            self.VECTOR_SEARCH_INDEX,
            self.embedding_model,
        )

        # Plain top-k similarity search. LangChain only honours a
        # "score_threshold" search kwarg when search_type is
        # "similarity_score_threshold"; combined with "similarity" it was
        # silently ignored, so it is not passed. Switching search types is not
        # a drop-in fix: that path rescales Atlas scores through the store's
        # relevance-score function -- verify its direction before enabling.
        self.retriever = self.vector_search.as_retriever(
            search_type="similarity", search_kwargs={"k": 10}
        )

    def format_docs(self, docs: list[Document]) -> str:
        """
        Parameters
        ----------
        **docs:** list[Document] - Retrieved documents to format

        Output
        ------
        str: The documents' page_content joined by blank lines ("" if docs is empty)

        Purpose
        -------
        Turns the retriever output into a single context string for the prompt.

        Assumptions
        -----------
        Every document has a page_content attribute.
        """
        return "\n\n".join(doc.page_content for doc in docs)

    def generate_response(self, input_dict: Dict[str, Any]) -> str:
        """
        Parameters
        ----------
        **input_dict:** Dict[str, Any] - Must contain 'context' and 'question'

        Output
        ------
        str: The model's answer ("" if the API returns no message content)

        Purpose
        -------
        Renders self.prompt with the context and question and sends it to the
        Hugging Face chat-completion API.

        Assumptions
        -----------
        - self.prompt has been created by setup_rag_chain
        - self.HF_TOKEN is a valid Hugging Face API token

        Notes
        -----
        Uses Qwen/Qwen2.5-1.5B-Instruct with temperature 0.2 and max_tokens 1000.
        """
        hf_client = InferenceClient(api_key=self.HF_TOKEN)
        formatted_prompt = self.prompt.format(**input_dict)

        response = hf_client.chat.completions.create(
            model="Qwen/Qwen2.5-1.5B-Instruct",
            messages=[
                # The rendered template (instructions + context + question)
                # is the system turn; the bare question is repeated as the
                # user turn.
                {"role": "system", "content": formatted_prompt},
                {"role": "user", "content": input_dict["question"]},
            ],
            max_tokens=1000,
            temperature=0.2,
        )

        # Message content is typed Optional by huggingface_hub; keep the
        # `-> str` contract instead of leaking None to the UI.
        return response.choices[0].message.content or ""

    def setup_rag_chain(self) -> None:
        """
        Parameters
        ----------
        None

        Output
        ------
        None

        Purpose
        -------
        Builds the prompt template (self.prompt) and the RAG chain
        (self.rag_chain) that maps a question string to an answer string.

        Assumptions
        -----------
        self.retriever has been initialized by setup_vector_search.

        Notes
        -----
        The dict stage runs both branches on the same input question: one
        retrieves and formats context, the other passes the question through.
        The resulting dict is fed to generate_response.
        """
        self.prompt = PromptTemplate.from_template(
            "Use the following pieces of context to answer the question at the end.\n"
            "\n"
            "START OF CONTEXT:\n"
            "{context}\n"
            "END OF CONTEXT:\n"
            "\n"
            "START OF QUESTION:\n"
            "{question}\n"
            "END OF QUESTION:\n"
            "\n"
            "If you do not know the answer, just say that you do not know.\n"
            "NEVER assume things.\n"
        )

        self.rag_chain = {
            "context": self.retriever | RunnableLambda(self.format_docs),
            "question": RunnablePassthrough(),
        } | RunnableLambda(self.generate_response)

    def process_question(self, question: str) -> str:
        """
        Parameters
        ----------
        **question:** str - The user's question

        Output
        ------
        str: The generated answer

        Purpose
        -------
        Runs a question through the RAG chain and returns the answer.

        Assumptions
        -----------
        - question is a non-empty string
        - self.rag_chain has been built by setup_rag_chain

        Notes
        -----
        Main entry point for question answering.
        """
        return self.rag_chain.invoke(question)
|
|
|
|
|
|
|
def setup_streamlit_ui() -> None:
    """
    Parameters
    ----------
    None

    Output
    ------
    None

    Purpose
    -------
    Configures the page, injects the custom stylesheet, and renders the
    title headings and the centred poster image.

    Assumptions
    -----------
    - CSS file exists at ./static/styles/style.css
    - Image file exists at ./static/images/poster.jpg
    - Both paths are relative to the current working directory, so the app
      must be launched from the directory that contains ./static

    Notes
    -----
    st.set_page_config is called first because older Streamlit versions
    require it to be the first Streamlit command of a run.
    """
    st.set_page_config(page_title="RAG Question Answering", page_icon="🤖")

    # Inject the custom stylesheet. Explicit UTF-8 so the read does not
    # depend on the platform's locale encoding.
    with open("./static/styles/style.css", encoding="utf-8") as css_file:
        st.markdown(f"<style>{css_file.read()}</style>", unsafe_allow_html=True)

    # Headings are raw HTML so they can be centred and styled inline.
    st.markdown(
        '<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">RAG Question Answering</h1>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">Using Documents and Research</h3>',
        unsafe_allow_html=True,
    )
    st.markdown(
        '<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">Digital Detectives: AI VS Real Images</h2>',
        unsafe_allow_html=True,
    )

    # Three equal columns; only the middle one is used, to centre the poster.
    _, center_col, _ = st.columns(3)
    with center_col:
        st.image("./static/images/poster.jpg")
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Parameters
    ----------
    None

    Output
    ------
    None

    Purpose
    -------
    Runs the Streamlit application: renders the UI, accepts a question and
    shows the generated answer.

    Assumptions
    -----------
    All required environment variables and static files are present.

    Notes
    -----
    Entry point for the application. Streamlit re-executes this on every
    interaction. The RAG system is rebuilt each time, while its clients are
    memoised with st.cache_resource.
    """
    setup_streamlit_ui()

    rag_system = RAGQuestionAnswering()

    query = st.text_input("Question:", key="question_input")

    if st.button("Submit", type="primary"):
        # Strip so a whitespace-only entry is treated as empty instead of
        # being sent to the embedding and chat APIs.
        question = query.strip()
        if question:
            with st.spinner("Generating response..."):
                response = rag_system.process_question(question)
            st.text_area("Answer:", value=response, height=200, disabled=True)
        else:
            st.warning("Please enter a question.")

    st.markdown(
        """
        <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;">
            <b>Check out our <a href="https://github.com/KeiraJames/CTP-Project-2024/tree/main" style="color: #FAF9F6;">GitHub repository</a></b>
        </p>
        """,
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()