Initial commit for RAG Question Answering system
Files added:
- app.py +71 -0
- config.py +14 -0
- model/main.py +44 -0
- model/questionAnsweringBot.py +28 -0
- model/retriever.py +42 -0
app.py
ADDED

```python
import streamlit as st

from config import CONFIG
from model.main import process_query

st.title("RAG Question Answering System")

# Instructions shown above the input widgets.
st.write("""
Welcome to the Retrieval-Augmented Generation (RAG) Question Answering System.

### What does this system do?
- Searches the first 1,000 records of the **aalksii/ml-arxiv-papers** dataset (the `MAX_NUM_OF_RECORDS` setting in config.py) for the information most relevant to your question, using **BM25** or **Semantic Search**.
- Generates answers from the retrieved documents using **OpenAI's GPT-4o-mini**.
- Provides citations for every piece of information to ensure transparency and trustworthiness.

### Instructions
1. **Enter your OpenAI API Key**: the app needs it to call the model.
2. **Ask your question**: type it in the input box.
3. **Choose a retrieval method**:
   - **BM25**: a keyword-based retrieval method.
   - **Semantic Search**: a context-based retrieval method powered by embeddings.
4. **Generate the answer**: click the "Generate Answer" button to retrieve relevant documents and generate a detailed answer.

Feel free to experiment with different questions and retrieval methods to explore how the system performs!
""")

# The app will not proceed without an API key.
llm_key = st.text_input("Enter your LLM API Key", type="password")
# if st.checkbox("Use Test API Key"):
#     llm_key = CONFIG['LLM_API_key']
if not llm_key:
    st.warning("Please provide your LLM API Key to proceed.")
    st.stop()

query = st.text_input("Enter your question")
retrieval_method = st.radio(
    "Select Retrieval Method",
    ("BM25", "Semantic Search")
)

if st.button("Generate Answer"):
    if not query.strip():
        st.warning("Please enter a question to process.")
    else:
        with st.spinner("Processing your query..."):
            try:
                retrieved_docs, answer = process_query(llm_key, query, retrieval_method)

                st.subheader("Retrieved Documents")
                for doc in retrieved_docs:
                    st.write(f"- {doc}")

                st.subheader("Generated Answer")
                st.text_area("Generated Answer", value=answer, height=CONFIG['TEXTAREA_HEIGHT'], disabled=True)
            except Exception as e:
                st.error(f"An error occurred: {e}")

# Style the answer text area.
st.markdown(
    """
    <style>
    .stTextArea {
        border: 2px solid #4CAF50;
        border-radius: 8px;
        padding: 10px;
        font-family: Arial, sans-serif;
        font-size: 14px;
        box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.1);
    }
    </style>
    """,
    unsafe_allow_html=True
)
```
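Once the remaining files are in place, the app starts the usual Streamlit way with `streamlit run app.py`. Note that the first query triggers dataset loading, BM25 indexing, and embedding computation (see model/main.py below), so it takes noticeably longer than subsequent ones.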
config.py
ADDED

```python
import os
from dotenv import load_dotenv

# Load variables from a local .env file, if present.
load_dotenv()

CONFIG = {
    "DATASET": "aalksii/ml-arxiv-papers",  # Hugging Face dataset of ML arXiv papers
    "MAX_NUM_OF_RECORDS": 1000,            # number of abstracts indexed
    "TEXTAREA_HEIGHT": 200,                # height of the answer box in app.py
    "CHUNK_SIZE": 200,                     # words per chunk (see Retriever.chunk_text)
    "OPENAI_ENGINE": "gpt-4o-mini",
    "MAX_TOKENS": 500,                     # completion budget for generated answers
    "TOP_DOCS": 3                          # chunks returned per query
}
```
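The commented-out "Use Test API Key" checkbox in app.py reads `CONFIG['LLM_API_key']`, which this dict never defines; that is also why `import os` and `load_dotenv` are otherwise unused here. A minimal sketch of how such an entry could be supplied through the already-loaded .env file; the environment variable name `LLM_API_KEY` is an assumption, not something this commit specifies:

```python
# Hypothetical extension of CONFIG; the env var name LLM_API_KEY is illustrative.
import os
from dotenv import load_dotenv

load_dotenv()

CONFIG = {
    # ... existing entries as above ...
    "LLM_API_key": os.getenv("LLM_API_KEY"),  # None when the .env entry is absent
}
```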
model/main.py
ADDED

```python
import streamlit as st

from model.questionAnsweringBot import QuestionAnsweringBot
from model.retriever import Retriever


def process_query(llm_key, query, retrieval_method):
    # Build the retriever once per session; indexing is expensive.
    if "retriever" not in st.session_state:
        st.session_state.retriever = Retriever()
        print("Loading and preparing dataset...")
        st.session_state.retriever.load_and_prepare_dataset()
        st.session_state.retriever.prepare_bm25()
        st.session_state.retriever.compute_embeddings()

    retriever = st.session_state.retriever

    if retrieval_method == "BM25":
        print("Retrieving documents using BM25...")
        retrieved_docs = retriever.retrieve_documents_bm25(query)
    else:
        print("Retrieving documents using Semantic Search...")
        retrieved_docs = retriever.retrieve_documents_semantic(query)

    bot = QuestionAnsweringBot(llm_key)
    prompt = getPrompt(retrieved_docs, query)
    answer = bot.generate_answer(prompt)

    return retrieved_docs, answer


def getPrompt(retrieved_docs, query):
    # Assemble the rules, the retrieved chunks, and the user query into one
    # prompt string. Each retrieved chunk is labeled "Document i", and the
    # citation rule refers to those labels.
    prompt = (
        "You are an LM integrated into a RAG system that answers questions based on provided documents.\n"
        "Rules:\n"
        "- Reply with the answer only and nothing but the answer.\n"
        "- Say 'I don't know' if you don't know the answer.\n"
        "- Use only the provided documents.\n"
        "- Citations are required. Include the document number in square brackets after the information (e.g., [Document 1]).\n\n"
        "Documents:\n"
    )

    for i, doc in enumerate(retrieved_docs):
        prompt += f"Document {i + 1}: {doc}\n"

    prompt += f"\nQuery: {query}\n"

    return prompt
```
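For reference, the prompt that getPrompt assembles has the following shape; the two documents here are toy strings, not dataset content:

```python
# Illustrative only: inspect the prompt built from two toy chunks.
docs = ["BM25 ranks documents by term frequency.", "Embeddings capture semantic similarity."]
print(getPrompt(docs, "How does BM25 rank documents?"))
# ...rules text...
# Documents:
# Document 1: BM25 ranks documents by term frequency.
# Document 2: Embeddings capture semantic similarity.
#
# Query: How does BM25 rank documents?
```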
model/questionAnsweringBot.py
ADDED

```python
# Uses the openai>=1.0 client interface; the legacy openai.ChatCompletion
# call was removed in v1 of the library.
from openai import OpenAI

from config import CONFIG

# Currently unused: the system message below is commented out in generate_answer.
PROMPT = """
You are a helpful assistant that can answer questions.
Rules:
- Reply with the answer only and nothing but the answer.
- Say "I don't know" if you don't know the answer.
- Use the provided context.
"""


class QuestionAnsweringBot:
    def __init__(self, llm_key):
        self.client = OpenAI(api_key=llm_key)

    def generate_answer(self, prompt):
        try:
            completion = self.client.chat.completions.create(
                model=CONFIG['OPENAI_ENGINE'],
                messages=[
                    # {"role": "system", "content": PROMPT},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=CONFIG['MAX_TOKENS'],
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            return f"An error occurred while generating the answer: {e}"
```
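On its own, the bot just wraps a single chat-completion call; a minimal usage sketch, where "sk-..." stands in for a real OpenAI key and the prompt would normally come from getPrompt:

```python
# Minimal usage sketch; "sk-..." is a placeholder, not a real credential.
bot = QuestionAnsweringBot("sk-...")
answer = bot.generate_answer("Document 1: The sky is blue.\n\nQuery: What color is the sky?")
print(answer)
```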
model/retriever.py
ADDED

```python
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util

from config import CONFIG


class Retriever:
    def __init__(self):
        self.corpus = None            # list of text chunks
        self.bm25 = None              # BM25 index over the chunks
        self.model = None             # sentence-embedding model
        self.chunk_embeddings = None  # tensor of chunk embeddings

    def load_and_prepare_dataset(self):
        # Take the first MAX_NUM_OF_RECORDS abstracts and split each into chunks.
        dataset = load_dataset(CONFIG['DATASET'])
        dataset = dataset['train'].select(range(CONFIG['MAX_NUM_OF_RECORDS']))
        dataset = dataset.map(lambda x: {'chunks': self.chunk_text(x['abstract'])})
        self.corpus = [chunk for chunks in dataset["chunks"] for chunk in chunks]

    def prepare_bm25(self):
        # BM25 operates on whitespace-tokenized chunks.
        tokenized_corpus = [doc.split(" ") for doc in self.corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def compute_embeddings(self):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)

    def chunk_text(self, text, chunk_size=CONFIG['CHUNK_SIZE']):
        # Split a text into chunks of at most chunk_size words.
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_documents_bm25(self, query):
        # Score every chunk against the query and keep the TOP_DOCS best.
        tokenized_query = query.split(" ")
        scores = self.bm25.get_scores(tokenized_query)
        top_docs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:CONFIG['TOP_DOCS']]
        return [self.corpus[i] for i in top_docs]

    def retrieve_documents_semantic(self, query):
        # Rank chunks by cosine similarity between query and chunk embeddings.
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
        top_chunks = scores.topk(CONFIG['TOP_DOCS']).indices
        return [self.corpus[i] for i in top_chunks]
```
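Outside Streamlit, the retriever can be exercised directly; a minimal sketch, where the query string is arbitrary and the first run downloads both the dataset and the all-MiniLM-L6-v2 model:

```python
# Minimal usage sketch: build both indexes once, then query either method.
retriever = Retriever()
retriever.load_and_prepare_dataset()  # fetches aalksii/ml-arxiv-papers on first use
retriever.prepare_bm25()
retriever.compute_embeddings()        # fetches all-MiniLM-L6-v2 on first use

print(retriever.retrieve_documents_bm25("graph neural networks"))
print(retriever.retrieve_documents_semantic("graph neural networks"))
```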