Charles Chan
commited on
Commit
·
8db8541
1
Parent(s):
b9bde1a
coding
Browse files
app.py
CHANGED
@@ -88,11 +88,13 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
88 |
with st.spinner("正在筛选本地数据集..."):
|
89 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
90 |
question_embedding_str = " ".join(map(str, question_embedding))
|
91 |
-
# print('question_embedding: ' + question_embedding_str)
|
92 |
docs_and_scores = st.session_state.db.similarity_search_with_score(question_embedding_str)
|
93 |
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
96 |
|
97 |
prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
|
98 |
print('prompt: ' + prompt)
|
@@ -100,10 +102,10 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
100 |
st.success("本地数据集筛选完成!")
|
101 |
print("本地数据集筛选完成!")
|
102 |
|
103 |
-
with st.spinner("
|
104 |
answer = get_answer(prompt)
|
105 |
-
st.success("
|
106 |
-
print("
|
107 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
108 |
except Exception as e:
|
109 |
st.error(f"问答过程出错:{e}")
|
@@ -149,7 +151,7 @@ with col3:
|
|
149 |
generate_answer(gemma, float(temperature), int(max_length), random_question)
|
150 |
|
151 |
with col4:
|
152 |
-
question = st.text_area("请输入问题", "
|
153 |
if st.button("提交输入的问题"):
|
154 |
if not question:
|
155 |
st.warning("请输入问题!")
|
|
|
88 |
with st.spinner("正在筛选本地数据集..."):
|
89 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
90 |
question_embedding_str = " ".join(map(str, question_embedding))
|
|
|
91 |
docs_and_scores = st.session_state.db.similarity_search_with_score(question_embedding_str)
|
92 |
|
93 |
+
context_list = []
|
94 |
+
for doc, score in docs_and_scores:
|
95 |
+
print(str(score) + ' : ' + doc.page_content)
|
96 |
+
context_list.append(doc.page_content)
|
97 |
+
context = "\n".join(context_list)
|
98 |
|
99 |
prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
|
100 |
print('prompt: ' + prompt)
|
|
|
102 |
st.success("本地数据集筛选完成!")
|
103 |
print("本地数据集筛选完成!")
|
104 |
|
105 |
+
with st.spinner("正在生成答案(基于本地数据集)..."):
|
106 |
answer = get_answer(prompt)
|
107 |
+
st.success("答案生成完毕(基于本地数据集)!")
|
108 |
+
print("答案生成完毕(基于本地数据集)!")
|
109 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
110 |
except Exception as e:
|
111 |
st.error(f"问答过程出错:{e}")
|
|
|
151 |
generate_answer(gemma, float(temperature), int(max_length), random_question)
|
152 |
|
153 |
with col4:
|
154 |
+
question = st.text_area("请输入问题", "太阳为什么是绿色的?")
|
155 |
if st.button("提交输入的问题"):
|
156 |
if not question:
|
157 |
st.warning("请输入问题!")
|