init project
- app.py +100 -0
- requirements.txt +14 -0
- src/__pycache__/indexing.cpython-311.pyc +0 -0
- src/app.py +125 -0
- src/chat.py +43 -0
- src/indexing.py +72 -0
- src/services/__pycache__/generate_embedding.cpython-311.pyc +0 -0
- src/services/generate_embedding.py +9 -0
- src/services/read_pdf.py +72 -0
- src/services/sentence-embedding.py +9 -0
- src/test.py +29 -0
app.py
ADDED
@@ -0,0 +1,100 @@
import streamlit as st
from pyvi.ViTokenizer import tokenize
from src.services.generate_embedding import generate_embedding
import pymongo
import time
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import os

os.environ["OPENAI_API_KEY"] = "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"

# Connect DB
client = pymongo.MongoClient(
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG"
)
db = client.rag
collection = db.pdf


def stream_response(answer: str):
    # Yield the answer word by word for a streamed typing effect
    for word in answer.split(" "):
        yield word + " "
        time.sleep(0.03)


# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"], unsafe_allow_html=True)


def retrieveByIndex(idxs):
    # Fetch the PDF chunks whose sequential "index" falls in the retrieved window
    docs = collection.find({"index": {"$in": idxs}})
    content = ""
    for doc in docs:
        content = content + " " + doc["page_content"]
    return content


def generateAnswer(context: str, question: str):
    # Vietnamese instruction: "Answer the user's question using the information
    # inside the <context> tags below. If the context contains no relevant
    # information, do not answer and reply only 'Tôi không biết' (I don't know)."
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                """Trả lời câu hỏi của người dùng dựa vào thông tin có trong thẻ <context> </context> được cho bên dưới. Nếu context không chứa những thông tin liên quan tới câu hỏi, thì đừng trả lời và chỉ trả lời là "Tôi không biết". <context> {context} </context> Câu hỏi: {question}""",
            ),
        ]
    )
    messages = prompt.invoke({"context": context, "question": question})
    print(messages)
    chat = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0.8)
    response = chat.invoke(messages)
    return response.content


# React to user input
if prompt := st.chat_input(""):
    # PhoBERT embeddings expect word-segmented Vietnamese input
    tokenized_prompt = tokenize(prompt)

    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    embedding = generate_embedding(tokenized_prompt)
    results = collection.aggregate(
        [
            {
                "$vectorSearch": {
                    "queryVector": embedding,
                    "path": "page_content_embedding",
                    "numCandidates": 5,
                    "limit": 5,
                    "index": "vector_index",
                }
            }
        ]
    )

    # Expand each hit into a window that also covers the next three chunks,
    # so neighbouring text from the PDF ends up in the context
    allIndexes = []
    for document in results:
        idx = document["index"]
        allIndexes.append(idx)
        allIndexes.append(idx + 1)
        allIndexes.append(idx + 2)
        allIndexes.append(idx + 3)
    print(allIndexes)

    context = retrieveByIndex(allIndexes)
    answer = generateAnswer(context, question=prompt)
    with st.chat_message("assistant"):
        st.markdown(answer, unsafe_allow_html=True)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": answer})
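Note (not part of this commit): the `$vectorSearch` stage above only works once a matching Atlas Vector Search index exists on the collection. A minimal sketch of the index definition it appears to assume, written as a Python dict in Atlas's index-definition format; the 768 dimensions match PhoBERT-base's output, while cosine similarity is an assumption:

index_definition = {
    # Assumed definition for the index named "vector_index" on db.rag.pdf;
    # typically created in the Atlas UI rather than from application code.
    "fields": [
        {
            "type": "vector",
            "path": "page_content_embedding",
            "numDimensions": 768,  # PhoBERT-base embedding size
            "similarity": "cosine",  # assumption; dotProduct or euclidean also possible
        }
    ]
}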
requirements.txt
ADDED
@@ -0,0 +1,14 @@
pymongo
pandas
pyvi
transformers
streamlit
torch
pypdf
langchain_community
langchain
langchain_openai
faiss-cpu
chromadb
pysqlite3-binary
sentence-transformers
src/__pycache__/indexing.cpython-311.pyc
ADDED
Binary file (2.66 kB)
src/app.py
ADDED
@@ -0,0 +1,125 @@
import streamlit as st
from pyvi.ViTokenizer import tokenize
from services.generate_embedding import generate_embedding
import pymongo
import time
from indexing import indexData, SHEET_ID, SHEET_NAME
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
import os

os.environ["OPENAI_API_KEY"] = "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"

# Connect DB
client = pymongo.MongoClient(
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG"
)
db = client.rag
collection = db.questionAndAnswers

with st.expander("Dataset"):
    col1, col2 = st.columns(2)
    with col1:
        st.markdown(
            """
            <div style="display:flex; gap: 16px; align-items: center">
                <a style="font-size: 14px"
                   href="https://docs.google.com/spreadsheets/d/1MKB6MHgL_lrPB1I69fj2VcVrgmSAMLVNZR1EwSyTSeA/edit#gid=0">Link
                   question & answers</a>
            </div>
            """,
            unsafe_allow_html=True,
        )

    with col2:
        # Re-index the Google Sheet into MongoDB on demand
        if st.button("Re-train"):
            placeholder = st.empty()
            placeholder.write("Training ...")
            indexData(SHEET_ID, SHEET_NAME)
            placeholder.write("Completed")


def generateAnswer(context: str, question: str):
    # Same Vietnamese instruction as in app.py: answer only from the <context>
    # tags, otherwise reply 'Tôi không biết' (I don't know)
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "user",
                """Trả lời câu hỏi của người dùng dựa vào thông tin có trong thẻ <context> </context> được cho bên dưới. Nếu context không chứa những thông tin liên quan tới câu hỏi, thì đừng trả lời và chỉ trả lời là "Tôi không biết". <context> {context} </context> Câu hỏi: {question}""",
            ),
        ]
    )
    messages = prompt.invoke({"context": context, "question": question})
    print(messages)
    chat = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0.8)
    response = chat.invoke(messages)
    return response.content


def stream_response(answer: str):
    for word in answer.split(" "):
        yield word + " "
        time.sleep(0.03)


# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"], unsafe_allow_html=True)

# React to user input
if prompt := st.chat_input(""):
    tokenized_prompt = tokenize(prompt)

    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": tokenized_prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(tokenized_prompt)

    embedding = generate_embedding(tokenized_prompt)
    results = collection.aggregate(
        [
            {
                "$vectorSearch": {
                    "queryVector": embedding,
                    "path": "question_embedding",
                    "numCandidates": 10,
                    "limit": 10,
                    "index": "vector_index",
                }
            }
        ]
    )

    # Build the context from the matched Q&A pairs and collect the matched
    # questions into an HTML list of suggestions
    possibleQuestions = ""
    context = ""
    question = ""
    index = 0
    for document in results:
        possibleQuestions = possibleQuestions + f"<li>{document['question']}</li>"
        context = context + "\n\n" + document["question"] + ": " + document["answer"]
        if index == 0:
            question = document["question"]
        index = index + 1
    # "Câu hỏi liên quan" = "Related questions"
    possibleQuestions = f"""<ol> <p style="font-weight: 600">Câu hỏi liên quan: </p> {possibleQuestions}</ol>"""

    answer = generateAnswer(context, prompt)
    response = f"""<p>{answer}</p>
    {possibleQuestions}
    """

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response, unsafe_allow_html=True)
        # st.markdown(f"""<p style="font-weight: 600">Question: {question}</p>""", unsafe_allow_html=True)
        # st.write_stream(stream_response(answer))
        # st.markdown(possibleQuestions, unsafe_allow_html=True)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
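Note (not part of this commit): `stream_response` above is defined but never called; the commented-out `st.write_stream` line hints at the intended wiring. A minimal sketch of that alternative, assuming the variables from the input handler above:

with st.chat_message("assistant"):
    # Stream the plain-text answer word by word via the generator,
    # then render the HTML list of related questions separately.
    st.write_stream(stream_response(answer))
    st.markdown(possibleQuestions, unsafe_allow_html=True)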
src/chat.py
ADDED
@@ -0,0 +1,43 @@
import streamlit as st
import random
import time


# Streamed response emulator
def response_generator():
    response = random.choice(
        [
            "Hello there! How can I assist you today?",
            "Hi, human! Is there anything I can help you with?",
            "Do you need help?",
        ]
    )
    for word in response.split():
        yield word + " "
        time.sleep(0.05)


st.title("Simple chat")

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Accept user input
if prompt := st.chat_input("What is up?"):
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        response = st.write_stream(response_generator())
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
src/indexing.py
ADDED
@@ -0,0 +1,72 @@
import pandas as pd

from services.generate_embedding import generate_embedding
from pyvi.ViTokenizer import tokenize
import pymongo

SHEET_ID = "1MKB6MHgL_lrPB1I69fj2VcVrgmSAMLVNZR1EwSyTSeA"
SHEET_NAME = "Q&A"

# Connect DB
client = pymongo.MongoClient(
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG"
)

db = client.rag
collection = db.questionAndAnswers


def insertQuestionAndAnswers(questionAndAnswers):
    return collection.insert_many(questionAndAnswers)


def deleteByUserId(user_id: str):
    return collection.delete_many({"user_id": user_id})


def readDataFromGoogleSheet(sheet_id: str, sheet_name: str):
    # Export the sheet as CSV through the gviz endpoint and load it with pandas
    url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={sheet_name}"
    df = pd.read_csv(url)
    items = []
    for index, row in df.iterrows():
        items.append(
            {
                "question": row["Question"],
                "answer": row["Answer"],
            }
        )
    print(f"read from google sheet {df.size} items")
    return items


def indexData(sheet_id: str, sheet_name: str):
    items = readDataFromGoogleSheet(sheet_id, sheet_name)
    questionAndAnswers = []
    for item in items:
        # Word-segment the question with pyvi before embedding it with PhoBERT
        tokenized_question = tokenize(item["question"])
        questionAndAnswer = {
            "question": tokenized_question,
            "answer": item["answer"],
            "question_embedding": generate_embedding(tokenized_question),
            "user_id": sheet_id,
        }
        questionAndAnswers.append(questionAndAnswer)
    # Drop any previously indexed rows for this sheet, then insert the new ones
    deleteByUserId(sheet_id)
    insertQuestionAndAnswers(questionAndAnswers)
    # for index, article in enumerate(data):
    #     if index < 6580:
    #         continue
    #
    #     if len(str(article['title'])) == 0 or len(str(article['description'])) == 0 or len(str(article['link'])) == 0:
    #         continue
    #
    #     tokenized_title = tokenize(article['title'])
    #     tokenized_description = tokenize(article['description'])
    #     article = {
    #         'title': tokenized_title,
    #         'description': tokenized_description,
    #         'link': article['link'],
    #         # 'title_embedding': generate_embedding(tokenized_title),
    #         'title_embedding': [],
    #         'description_embedding': generate_embedding(tokenized_title + ": " + tokenized_description),
    #     }
    #     print(f"processed {index}/{len(articles)}")
    #     save_db(article)
src/services/__pycache__/generate_embedding.cpython-311.pyc
ADDED
Binary file (1.09 kB)
src/services/generate_embedding.py
ADDED
@@ -0,0 +1,9 @@
from transformers import AutoModel, AutoTokenizer

PhobertTokenizer = AutoTokenizer.from_pretrained("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
model = AutoModel.from_pretrained("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")


def generate_embedding(sentence: str):
    # Embed a (word-segmented) Vietnamese sentence with SimCSE-PhoBERT and
    # return the pooler output as a plain list of floats for MongoDB
    inputs = PhobertTokenizer(sentence, padding=True, truncation=True, return_tensors="pt")
    embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
    return embeddings[0].detach().numpy().tolist()
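A quick usage sketch for `generate_embedding` (hypothetical, assuming it is run from the repo root with the packages in requirements.txt installed):

from pyvi.ViTokenizer import tokenize
from src.services.generate_embedding import generate_embedding

# PhoBERT models expect word-segmented Vietnamese, hence the pyvi tokenize step
embedding = generate_embedding(tokenize("Chuyển đổi số là gì?"))  # "What is digital transformation?"
print(len(embedding))  # 768 for phobert-base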
src/services/read_pdf.py
ADDED
@@ -0,0 +1,72 @@
import os
from pyvi.ViTokenizer import tokenize
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
import pymongo
from generate_embedding import generate_embedding

os.environ["OPENAI_API_KEY"] = "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"

# Connect DB
client = pymongo.MongoClient(
    "mongodb+srv://rag:[email protected]/?retryWrites=true&w=majority&appName=RAG"
)

db = client.rag
collection = db.pdf


def insertData(chunk):
    return collection.insert_many(chunk)


def deleteByUserId(user_id: str):
    return collection.delete_many({"user_id": user_id})


def readFromPDF():
    # Load the PDF and skip the front matter (keep pages 10 and later)
    loader = PyPDFLoader("data/cds.pdf")
    pages = loader.load_and_split()
    pages = list(filter(lambda page: page.metadata["page"] >= 10, pages))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=768, chunk_overlap=200)
    chunks = text_splitter.split_documents(pages)
    items = []
    for index, chunk in enumerate(chunks):
        print(index)
        # The sequential "index" lets the chat app fetch neighbouring chunks later
        items.append({"page_content": chunk.page_content, "index": index})
    return items


def indexData(user_id: str):
    items = readFromPDF()
    contents = []
    for item in items:
        tokenized_page_content = tokenize(item["page_content"])
        content = {
            "page_content": item["page_content"],
            "page_content_embedding": generate_embedding(tokenized_page_content),
            "user_id": user_id,
            "index": item["index"],
        }
        contents.append(content)
    deleteByUserId(user_id)
    insertData(contents)


indexData("cds.pdf")

# prompt = hub.pull("rlm/rag-prompt")
# llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

# def format_docs(docs):
#     return "\n\n".join(doc.page_content for doc in docs)


# rag_chain = (
#     {"context": retriever | format_docs, "question": RunnablePassthrough()}
#     | prompt
#     | llm
#     | StrOutputParser()
# )
src/services/sentence-embedding.py
ADDED
@@ -0,0 +1,9 @@
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)

model = SentenceTransformerEmbeddings(model_name="vinai/phobert-base-v2")

query = "This framework generates embeddings for each input sentence"
sentence_embeddings = model.embed_query(query)
print(len(sentence_embeddings))
src/test.py
ADDED
@@ -0,0 +1,29 @@
# coding: utf8
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "vinai/PhoGPT-4B"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Move the model to the same device as the inputs below; the original left it
# on the CPU, which fails when a GPU is available
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Vietnamese prompt: "### Question: Write a social essay on road-traffic safety \n### Answer:"
inputs = tokenizer('### Câu hỏi: Viết bài văn nghị luận xã hội về an toàn giao thông \n### Trả lời:', return_tensors='pt').to(device)
print(inputs)

outputs = model.generate(
    inputs=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    do_sample=True,
    temperature=1.0,
    top_k=50,
    top_p=0.9,
    max_new_tokens=1024,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

response = tokenizer.decode(outputs[0])
print(response)