Charles Chan
commited on
Commit
·
10b5e55
1
Parent(s):
4b8fc6b
coding
Browse files
app.py
CHANGED
@@ -51,7 +51,6 @@ if not st.session_state.vector_created:
|
|
51 |
st.stop()
|
52 |
st.session_state.vector_created = True
|
53 |
|
54 |
-
# 问答函数
|
55 |
if "repo_id" not in st.session_state:
|
56 |
st.session_state.repo_id = ''
|
57 |
if "temperature" not in st.session_state:
|
@@ -64,10 +63,9 @@ def get_answer(prompt):
|
|
64 |
# 去掉 prompt 的内容
|
65 |
answer = answer.replace(prompt, "").strip()
|
66 |
print(answer)
|
67 |
-
st.success("答案已经生成!")
|
68 |
-
print("答案已经生成!")
|
69 |
return answer
|
70 |
|
|
|
71 |
def answer_question(repo_id, temperature, max_length, question):
|
72 |
# 初始化 Gemma 模型
|
73 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
@@ -85,6 +83,10 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
85 |
|
86 |
# 获取答案
|
87 |
try:
|
|
|
|
|
|
|
|
|
88 |
with st.spinner("正在筛选本地数据集..."):
|
89 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
90 |
question_embedding_str = " ".join(map(str, question_embedding))
|
@@ -101,8 +103,9 @@ def answer_question(repo_id, temperature, max_length, question):
|
|
101 |
print("本地数据集筛选完成!")
|
102 |
|
103 |
with st.spinner("正在生成答案..."):
|
104 |
-
pure_answer = get_answer(question)
|
105 |
answer = get_answer(prompt)
|
|
|
|
|
106 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
107 |
except Exception as e:
|
108 |
st.error(f"问答过程出错:{e}")
|
|
|
51 |
st.stop()
|
52 |
st.session_state.vector_created = True
|
53 |
|
|
|
54 |
if "repo_id" not in st.session_state:
|
55 |
st.session_state.repo_id = ''
|
56 |
if "temperature" not in st.session_state:
|
|
|
63 |
# 去掉 prompt 的内容
|
64 |
answer = answer.replace(prompt, "").strip()
|
65 |
print(answer)
|
|
|
|
|
66 |
return answer
|
67 |
|
68 |
+
# 问答函数
|
69 |
def answer_question(repo_id, temperature, max_length, question):
|
70 |
# 初始化 Gemma 模型
|
71 |
if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
|
|
|
83 |
|
84 |
# 获取答案
|
85 |
try:
|
86 |
+
with st.spinner("正在生成答案(基于模型自身)..."):
|
87 |
+
pure_answer = get_answer(question)
|
88 |
+
st.success("答案生成完毕(基于模型自身)!")
|
89 |
+
print("答案生成完毕(基于模型自身)!")
|
90 |
with st.spinner("正在筛选本地数据集..."):
|
91 |
question_embedding = st.session_state.embeddings.embed_query(question)
|
92 |
question_embedding_str = " ".join(map(str, question_embedding))
|
|
|
103 |
print("本地数据集筛选完成!")
|
104 |
|
105 |
with st.spinner("正在生成答案..."):
|
|
|
106 |
answer = get_answer(prompt)
|
107 |
+
st.success("答案生成完毕!")
|
108 |
+
print("答案生成完毕!")
|
109 |
return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
|
110 |
except Exception as e:
|
111 |
st.error(f"问答过程出错:{e}")
|