droushb committed on
Commit
dc75be1
1 Parent(s): a95bcbf

Initial commit for RAG Question Answering system

Browse files
Files changed (5) hide show
  1. app.py +71 -0
  2. config.py +14 -0
  3. model/main.py +44 -0
  4. model/questionAnsweringBot.py +28 -0
  5. model/retriever.py +42 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from config import CONFIG
3
+ from model.main import process_query
4
+
5
st.title("RAG Question Answering System")

# Instructions shown at the top of the page. The corpus size is read from
# CONFIG so the text cannot drift from what the retriever actually loads.
st.write(f"""
Welcome to the Retrieval-Augmented Generation (RAG) Question Answering System.

### What does this system do?
- Searches through a collection of the first {CONFIG['MAX_NUM_OF_RECORDS']:,} documents of the dataset to find the most relevant information based on your question using **BM25** and **Semantic Search**.
- Generates accurate answers using the retrieved documents with the power of **OpenAI API GPT-4o-mini**.
- Provides citations for every piece of information to ensure transparency and trustworthiness.

### Instructions
1. **Enter your OpenAI API Key**: You can use your own key.
2. **Ask Your Question**: Type your question in the input box.
3. **Choose a Retrieval Method**:
    - **BM25**: A keyword-based retrieval method.
    - **Semantic Search**: A context-based retrieval method powered by embeddings.
4. **Generate the Answer**: Click the "Generate Answer" button to retrieve relevant documents and generate a detailed answer.

Feel free to experiment with different questions and retrieval methods to explore how the system performs!
""")

# The key is requested on every run; nothing below executes without it.
llm_key = st.text_input("Enter your LLM API Key", type="password")
if not llm_key:
    st.warning("Please provide your LLM API Key to proceed.")
    st.stop()

query = st.text_input("Enter your question")
retrieval_method = st.radio(
    "Select Retrieval Method",
    ("BM25", "Semantic Search"),
)

# Label matches the wording used in the instructions above.
if st.button("Generate Answer"):
    if not query.strip():
        st.warning("Please enter a question to process.")
    else:
        with st.spinner("Processing your query..."):
            # Top-level UI boundary: surface any failure to the user
            # instead of letting the script crash.
            try:
                retrieved_docs, answer = process_query(llm_key, query, retrieval_method)

                st.subheader("Retrieved Documents")
                for doc in retrieved_docs:
                    st.write(f"- {doc}")

                st.subheader("Generated Answer")
                st.text_area(
                    "Generated Answer",
                    value=answer,
                    height=CONFIG['TEXTAREA_HEIGHT'],
                    disabled=True,
                )
            except Exception as e:
                st.error(f"An error occurred: {e}")

# Cosmetic styling for the answer text area.
st.markdown(
    """
    <style>
    .stTextArea {
        border: 2px solid #4CAF50;
        border-radius: 8px;
        padding: 10px;
        font-family: Arial, sans-serif;
        font-size: 14px;
        box-shadow: 2px 2px 5px rgba(0, 0, 0, 0.1);
    }
    </style>
    """,
    unsafe_allow_html=True,
)
config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ load_dotenv()
5
+
6
# Shared settings read by the Streamlit app and the model package.
CONFIG = dict(
    # Dataset identifier handed to `datasets.load_dataset`.
    DATASET="aalksii/ml-arxiv-papers",
    # Number of leading records of the train split that get indexed.
    MAX_NUM_OF_RECORDS=1000,
    # Height passed to the answer text area in the UI.
    TEXTAREA_HEIGHT=200,
    # Words per chunk when abstracts are split for retrieval.
    CHUNK_SIZE=200,
    # Model name sent with the chat completion request.
    OPENAI_ENGINE="gpt-4o-mini",
    # `max_tokens` limit for the generated answer.
    MAX_TOKENS=500,
    # How many top-scoring chunks each retrieval method returns.
    TOP_DOCS=3,
)
model/main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from model.questionAnsweringBot import QuestionAnsweringBot
3
+ from model.retriever import Retriever
4
+
5
def process_query(llm_key, query, retrieval_method):
    """Retrieve supporting passages for *query* and generate an answer.

    Args:
        llm_key: API key handed to ``QuestionAnsweringBot``.
        query: The user's question.
        retrieval_method: ``"BM25"`` selects BM25 retrieval; any other
            value selects semantic (embedding) retrieval.

    Returns:
        A ``(retrieved_docs, answer)`` tuple: the retrieved passages and
        the text returned by the bot.
    """
    if "retriever" not in st.session_state:
        print("Loading and preparing dataset...")
        # Build the retriever in a local variable and publish it to the
        # session only once every preparation step has succeeded. Storing
        # it first would leave a half-initialised object cached if any step
        # raised, and later calls would skip initialisation and fail.
        new_retriever = Retriever()
        new_retriever.load_and_prepare_dataset()
        new_retriever.prepare_bm25()
        new_retriever.compute_embeddings()
        st.session_state.retriever = new_retriever

    retriever = st.session_state.retriever

    if retrieval_method == "BM25":
        print("Retrieving documents using BM25...")
        retrieved_docs = retriever.retrieve_documents_bm25(query)
    else:
        print("Retrieving documents using Semantic Search...")
        retrieved_docs = retriever.retrieve_documents_semantic(query)

    bot = QuestionAnsweringBot(llm_key)
    prompt = getPrompt(retrieved_docs, query)
    answer = bot.generate_answer(prompt)

    return retrieved_docs, answer
27
+
28
def getPrompt(retrieved_docs, query):
    """Build the LLM prompt from the retrieved passages and the question.

    Args:
        retrieved_docs: Iterable of passage strings; each is labelled
            ``Document 1``, ``Document 2``, ... in the order given.
        query: The user's question, appended after the passages.

    Returns:
        The full prompt string: rules, the labelled passages, then the query.
    """
    # The citation rule refers only to the "Document N" labels that are
    # actually emitted below, so the model is never asked to cite an
    # identifier that does not appear in the prompt.
    header = (
        "You are an LM integrated into an RAG system that answers questions based on provided documents.\n"
        "Rules:\n"
        "- Reply with the answer only and nothing but the answer.\n"
        "- Say 'I don't know' if you don't know the answer.\n"
        "- Use only the provided documents.\n"
        "- Citations are required. Include the document number in square brackets after the information (e.g., [Document 1]).\n\n"
        "Documents:\n"
    )

    documents = "".join(
        f"Document {number}: {doc}\n"
        for number, doc in enumerate(retrieved_docs, start=1)
    )

    return f"{header}{documents}\nQuery: {query}\n"
model/questionAnsweringBot.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ from config import CONFIG
3
+
4
# Generic system-style instructions for the assistant.
# NOTE(review): no live code visible in this file reads this constant; the
# only reference is a commented-out system message in generate_answer.
PROMPT = """
You are a helpful assistant that can answer questions.
Rules:
- Reply with the answer only and nothing but the answer.
- Say "I don't know" if you don't know the answer.
- Use the provided context.
"""
11
+
12
class QuestionAnsweringBot:
    """Thin wrapper around the OpenAI chat completion endpoint."""

    def __init__(self, llm_key):
        """Remember the API key used for this bot's requests.

        Args:
            llm_key: OpenAI API key supplied by the user.
        """
        # Keep the key on the instance instead of writing it to the
        # module-global ``openai.api_key``. That global is shared by the
        # whole process, so one session's key could be used for (or
        # overwritten by) another session's request.
        self.llm_key = llm_key

    def generate_answer(self, prompt):
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Full prompt text to send.

        Returns:
            The stripped content of the first choice, or a string starting
            with "An error occurred while generating the answer:" if the
            request or response handling raises.
        """
        try:
            completion = openai.ChatCompletion.create(
                # Per-request key; see __init__ for why no global is set.
                api_key=self.llm_key,
                model=CONFIG['OPENAI_ENGINE'],
                # Only a user message is sent; no system message is included.
                messages=[
                    {"role": "user", "content": prompt},
                ],
                max_tokens=CONFIG['MAX_TOKENS'],
            )
            return completion['choices'][0]['message']['content'].strip()
        except Exception as e:
            # Deliberate best-effort: callers receive the failure as text.
            return f"An error occurred while generating the answer: {e}"
model/retriever.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from config import CONFIG
3
+ from rank_bm25 import BM25Okapi
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+
7
class Retriever:
    """Loads a text corpus and retrieves chunks by BM25 or embedding similarity.

    Call ``load_and_prepare_dataset`` first, then ``prepare_bm25`` and/or
    ``compute_embeddings`` before the matching ``retrieve_documents_*`` method.
    """

    def __init__(self):
        # Flat list of text chunks; filled by load_and_prepare_dataset().
        self.corpus = None
        # BM25Okapi index over the corpus; filled by prepare_bm25().
        self.bm25 = None
        # SentenceTransformer model; filled by compute_embeddings().
        self.model = None
        # Embeddings of every corpus chunk; filled by compute_embeddings().
        self.chunk_embeddings = None

    @staticmethod
    def _tokenize(text):
        """Lower-case *text* and split it on any whitespace run.

        Used for both the corpus and the query so BM25 matching is
        case-insensitive and never sees empty tokens.
        """
        return text.lower().split()

    def load_and_prepare_dataset(self):
        """Load the configured dataset and flatten its abstracts into chunks."""
        dataset = load_dataset(CONFIG['DATASET'])
        dataset = dataset['train'].select(range(CONFIG['MAX_NUM_OF_RECORDS']))
        dataset = dataset.map(lambda x: {'chunks': self.chunk_text(x['abstract'])})
        self.corpus = [chunk for chunks in dataset["chunks"] for chunk in chunks]

    def prepare_bm25(self):
        """Build the BM25 index over the tokenized corpus."""
        tokenized_corpus = [self._tokenize(doc) for doc in self.corpus]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def compute_embeddings(self):
        """Load the 'all-MiniLM-L6-v2' model and embed every corpus chunk."""
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.chunk_embeddings = self.model.encode(self.corpus, convert_to_tensor=True)

    def chunk_text(self, text, chunk_size=CONFIG['CHUNK_SIZE']):
        """Split *text* into pieces of at most *chunk_size* whitespace-separated words.

        Returns:
            A list of strings, each the words of one chunk joined by single
            spaces; empty when *text* contains no words.
        """
        words = text.split()
        return [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    def retrieve_documents_bm25(self, query):
        """Return the ``CONFIG['TOP_DOCS']`` chunks with the highest BM25 score."""
        scores = self.bm25.get_scores(self._tokenize(query))
        top_docs = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:CONFIG['TOP_DOCS']]
        return [self.corpus[i] for i in top_docs]

    def retrieve_documents_semantic(self, query):
        """Return the chunks whose embeddings are most cosine-similar to *query*."""
        query_embedding = self.model.encode(query, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(query_embedding, self.chunk_embeddings)[0]
        # topk raises if asked for more items than exist, so cap k at the
        # corpus size.
        k = min(CONFIG['TOP_DOCS'], len(self.corpus))
        # Convert to plain ints so the list is indexed with integers rather
        # than tensor elements.
        top_chunks = scores.topk(k).indices.tolist()
        return [self.corpus[i] for i in top_chunks]