Charles Chan
committed on
Commit
·
1e06430
1
Parent(s):
a6c563b
coding
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from langchain_community.vectorstores import FAISS
|
|
6 |
from datasets import load_dataset
|
7 |
|
8 |
# Streamlit 界面
|
9 |
-
st.title("
|
10 |
|
11 |
# 使用 假知识 数据集
|
12 |
if "data_list" not in st.session_state:
|
@@ -17,6 +17,9 @@ if not st.session_state.data_list:
|
|
17 |
try:
|
18 |
with st.spinner("正在读取数据库..."):
|
19 |
dataset = load_dataset("zeerd/fake_knowledge")
|
|
|
|
|
|
|
20 |
data_list = []
|
21 |
answer_list = []
|
22 |
for example in dataset["train"]:
|
@@ -66,6 +69,10 @@ def get_answer(prompt):
|
|
66 |
# 问答函数
|
67 |
def answer_question(repo_id, temperature, max_length, question):
|
68 |
# 初始化 Gemma 模型
|
|
|
|
|
|
|
|
|
69 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
70 |
try:
|
71 |
with st.spinner("正在初始化 Gemma 模型..."):
|
@@ -86,9 +93,9 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
86 |
st.success("答案生成完毕(基于模型自身)!")
|
87 |
print("答案生成完毕(基于模型自身)!")
|
88 |
with st.spinner("正在筛选本地数据集..."):
|
89 |
-
|
90 |
# question_embedding_str = " ".join(map(str, question_embedding))
|
91 |
-
docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question)
|
92 |
|
93 |
context_list = []
|
94 |
for doc, score in docs_and_scores:
|
|
|
6 |
from datasets import load_dataset
|
7 |
|
8 |
# Streamlit 界面
|
9 |
+
st.title("外挂知识库问答系统")
|
10 |
|
11 |
# 使用 假知识 数据集
|
12 |
if "data_list" not in st.session_state:
|
|
|
17 |
try:
|
18 |
with st.spinner("正在读取数据库..."):
|
19 |
dataset = load_dataset("zeerd/fake_knowledge")
|
20 |
+
# 输出前五条数据
|
21 |
+
print(dataset["train"][:5])
|
22 |
+
|
23 |
data_list = []
|
24 |
answer_list = []
|
25 |
for example in dataset["train"]:
|
|
|
69 |
# 问答函数
|
70 |
def answer_question(repo_id, temperature, max_length, question):
|
71 |
# 初始化 Gemma 模型
|
72 |
+
print('repo_id: ' + repo_id)
|
73 |
+
print('temperature: ' + str(temperature))
|
74 |
+
print('max_length: ' + str(max_length))
|
75 |
+
|
76 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
77 |
try:
|
78 |
with st.spinner("正在初始化 Gemma 模型..."):
|
|
|
93 |
st.success("答案生成完毕(基于模型自身)!")
|
94 |
print("答案生成完毕(基于模型自身)!")
|
95 |
with st.spinner("正在筛选本地数据集..."):
|
96 |
+
question_embedding = st.session_state.embeddings.embed_query(question)
|
97 |
# question_embedding_str = " ".join(map(str, question_embedding))
|
98 |
+
docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question, 8, question_embedding)
|
99 |
|
100 |
context_list = []
|
101 |
for doc, score in docs_and_scores:
|