import streamlit as st
from pyvi.ViTokenizer import tokenize
from src.services.generate_embedding import generate_embedding
import pymongo
import time
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
# Connect to MongoDB Atlas. (The connection string appears redacted here;
# in practice, load it from an environment variable instead of hardcoding it.)
client = pymongo.MongoClient(
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG"
)
db = client.rag
collection = db.pdf
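# Assumption (not shown in this file): the "pdf" collection holds chunked
# documents with fields "index" (sequential chunk number), "page_content",
# and "page_content_embedding", plus an Atlas Vector Search index named
# "vector_index" on the embedding field -- see the $vectorSearch stage below.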
def stream_response(answer: str):
    # Yield the answer word by word so st.write_stream can render it
    # with a typewriter effect.
    for word in answer.split(" "):
        yield word + " "
        time.sleep(0.03)
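# Streamlit re-runs this script top to bottom on every interaction, so the
# conversation is kept in st.session_state and replayed on each rerun.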
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"], unsafe_allow_html=True)
def retrieve_by_index(idxs):
    # Fetch the requested chunks sorted by their position in the source
    # document so the concatenated context reads in order.
    docs = collection.find({"index": {"$in": idxs}}).sort("index", 1)
    return " ".join(doc["page_content"] for doc in docs)
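# The prompt below is in Vietnamese. Translation: "Answer the user's question
# based on the information in the <context> </context> tags given below. If
# the context contains no information relevant to the question, do not answer
# and reply only with 'Tôi không biết' ('I don't know')."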
def generate_answer(context: str, question: str):
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                """Trả lời câu hỏi của người dùng dựa vào thông tin có trong thẻ <context> </context> được cho bên dưới. Nếu context không chứa những thông tin liên quan tới câu hỏi, thì đừng trả lời và chỉ trả lời là "Tôi không biết". <context> {context} </context> Câu hỏi: {question}""",
            ),
        ]
    )
    messages = prompt.invoke({"context": context, "question": question})
    print(messages)
    # ChatOpenAI reads OPENAI_API_KEY from the environment.
    chat = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0.8)
    response = chat.invoke(messages)
    return response.content
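# Each user turn below runs the full RAG loop: word-segment the Vietnamese
# query with pyvi, embed it, retrieve the nearest chunks via Atlas vector
# search, widen each hit into a context window, and generate the answer.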
# React to user input
if prompt := st.chat_input(""):
    # pyvi word segmentation should match the preprocessing used when the
    # stored document embeddings were created.
    tokenized_prompt = tokenize(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    embedding = generate_embedding(tokenized_prompt)
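    # Assumption: generate_embedding (src/services) returns a plain list of
    # floats with the same dimensionality as the stored vectors in
    # "page_content_embedding"; $vectorSearch requires matching dimensions.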
    results = collection.aggregate(
        [
            {
                "$vectorSearch": {
                    "queryVector": embedding,
                    "path": "page_content_embedding",
                    # Atlas recommends numCandidates well above limit
                    # (roughly 10-20x) for better recall.
                    "numCandidates": 5,
                    "limit": 5,
                    "index": "vector_index",
                }
            }
        ]
    )
    # Widen each hit into a window of four consecutive chunks (the hit plus
    # the next three); a set avoids duplicates when windows overlap.
    all_idxs = set()
    for document in results:
        idx = document["index"]
        all_idxs.update(range(idx, idx + 4))
    print(sorted(all_idxs))
    context = retrieve_by_index(list(all_idxs))
    answer = generate_answer(context, question=prompt)
    with st.chat_message("assistant"):
        # Stream the answer word by word (st.write_stream consumes the
        # generator and returns the concatenated string).
        st.write_stream(stream_response(answer))
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": answer})