# pdf_rag_chatbot / app.py
# hohieu's picture
# remove open API key
# 73f8358
import streamlit as st
from pyvi.ViTokenizer import tokenize
from src.services.generate_embedding import generate_embedding
import pymongo
import time
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import os
# Connect DB
# NOTE(review): the connection string embeds credentials. Supply it through the
# MONGODB_URI environment variable so the secret stays out of source control;
# the literal below is kept only as a backward-compatible fallback.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG",
)
client = pymongo.MongoClient(MONGODB_URI)
db = client.rag
# Collection read by the vector search and by the index lookup below.
collection = db.pdf
def stream_response(answer: str):
    """Yield *answer* piece by piece, splitting on single spaces.

    Every yielded chunk is one piece of ``answer.split(" ")`` followed by a
    single space. After a chunk has been consumed the generator sleeps for
    0.03 seconds before moving on to the next piece.
    """
    pause_seconds = 0.03
    pieces = answer.split(" ")
    for piece in pieces:
        yield f"{piece} "
        time.sleep(pause_seconds)
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        # Raw HTML is enabled only for assistant messages. The user's text is
        # rendered as plain markdown when first submitted, so replaying it with
        # HTML enabled would both change its appearance on rerun and let typed
        # markup be interpreted as HTML.
        st.markdown(
            message["content"],
            unsafe_allow_html=message["role"] == "assistant",
        )
def retriveByIndex(idxs):
    """Concatenate the text of every document whose "index" is in *idxs*.

    Queries the module-level ``collection`` with an ``$in`` filter on the
    ``index`` field and returns one string in which each document's
    ``page_content`` is preceded by a single space, in the order the cursor
    yields the documents. Returns an empty string when nothing matches.
    """
    cursor = collection.find({"index": {"$in": idxs}})
    return "".join(" " + doc["page_content"] for doc in cursor)
def generateAnswer(context: str, question: str):
    """Ask the chat model to answer *question* from the supplied *context*.

    A single user message (written in Vietnamese) tells the model to answer
    using only the text inside the <context> tags and to say it does not know
    when the context holds nothing relevant. The filled-in prompt is printed
    before the model is called.

    Returns the ``content`` attribute of the model's response.
    """
    instruction = """Trả lời câu hỏi của người dùng dựa vào thông tin có trong thẻ <context> </context> được cho bên dưới. Nếu context không chứa những thông tin liên quan tới câu hỏi, thì đừng trả lời và chỉ trả lời là "Tôi không biết". <context> {context} </context> Câu hỏi: {question}"""
    template = ChatPromptTemplate.from_messages([("user", instruction)])

    filled_prompt = template.invoke({"context": context, "question": question})
    print(filled_prompt)

    llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0.8)
    reply = llm.invoke(filled_prompt)
    return reply.content
# React to user input
if prompt := st.chat_input(""):
    tokenized_prompt = tokenize(prompt)

    # Store the user's message in the history, then show it in the chat
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Embed the tokenized question and run the vector search with it
    query_vector = generate_embedding(tokenized_prompt)
    vector_search_stage = {
        "$vectorSearch": {
            "queryVector": query_vector,
            "path": "page_content_embedding",
            "numCandidates": 5,
            "limit": 5,
            "index": "vector_index",
        }
    }
    hits = collection.aggregate([vector_search_stage])

    # Every hit contributes its own index and the three indexes after it
    wanted_indexes = []
    for hit in hits:
        base = hit["index"]
        wanted_indexes.extend((base, base + 1, base + 2, base + 3))
    print(wanted_indexes)

    # Build the context from those documents and ask the model
    context = retriveByIndex(wanted_indexes)
    answer = generateAnswer(context, question=prompt)

    # Show the assistant's reply, then store it in the history
    with st.chat_message("assistant"):
        st.markdown(answer, unsafe_allow_html=True)
    st.session_state.messages.append({"role": "assistant", "content": answer})