Charles Chan commited on
Commit
10b5e55
·
1 Parent(s): 4b8fc6b
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -51,7 +51,6 @@ if not st.session_state.vector_created:
51
  st.stop()
52
  st.session_state.vector_created = True
53
 
54
- # 问答函数
55
  if "repo_id" not in st.session_state:
56
  st.session_state.repo_id = ''
57
  if "temperature" not in st.session_state:
@@ -64,10 +63,9 @@ def get_answer(prompt):
64
  # 去掉 prompt 的内容
65
  answer = answer.replace(prompt, "").strip()
66
  print(answer)
67
- st.success("答案已经生成!")
68
- print("答案已经生成!")
69
  return answer
70
 
 
71
  def answer_question(repo_id, temperature, max_length, question):
72
  # 初始化 Gemma 模型
73
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
@@ -85,6 +83,10 @@ def answer_question(repo_id, temperature, max_length, question):
85
 
86
  # 获取答案
87
  try:
 
 
 
 
88
  with st.spinner("正在筛选本地数据集..."):
89
  question_embedding = st.session_state.embeddings.embed_query(question)
90
  question_embedding_str = " ".join(map(str, question_embedding))
@@ -101,8 +103,9 @@ def answer_question(repo_id, temperature, max_length, question):
101
  print("本地数据集筛选完成!")
102
 
103
  with st.spinner("正在生成答案..."):
104
- pure_answer = get_answer(question)
105
  answer = get_answer(prompt)
 
 
106
  return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
107
  except Exception as e:
108
  st.error(f"问答过程出错:{e}")
 
51
  st.stop()
52
  st.session_state.vector_created = True
53
 
 
54
  if "repo_id" not in st.session_state:
55
  st.session_state.repo_id = ''
56
  if "temperature" not in st.session_state:
 
63
  # 去掉 prompt 的内容
64
  answer = answer.replace(prompt, "").strip()
65
  print(answer)
 
 
66
  return answer
67
 
68
+ # 问答函数
69
  def answer_question(repo_id, temperature, max_length, question):
70
  # 初始化 Gemma 模型
71
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
 
83
 
84
  # 获取答案
85
  try:
86
+ with st.spinner("正在生成答案(基于模型自身)..."):
87
+ pure_answer = get_answer(question)
88
+ st.success("答案生成完毕(基于模型自身)!")
89
+ print("答案生成完毕(基于模型自身)!")
90
  with st.spinner("正在筛选本地数据集..."):
91
  question_embedding = st.session_state.embeddings.embed_query(question)
92
  question_embedding_str = " ".join(map(str, question_embedding))
 
103
  print("本地数据集筛选完成!")
104
 
105
  with st.spinner("正在生成答案..."):
 
106
  answer = get_answer(prompt)
107
+ st.success("答案生成完毕!")
108
+ print("答案生成完毕!")
109
  return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
110
  except Exception as e:
111
  st.error(f"问答过程出错:{e}")