Charles Chan committed
Commit 742b788
1 Parent(s): 38e677d
Files changed (1)
  1. app.py +16 -6
app.py CHANGED
@@ -58,6 +58,15 @@ if "temperature" not in st.session_state:
     st.session_state.temperature = ''
 if "max_length" not in st.session_state:
     st.session_state.max_length = ''
+
+def get_answer(prompt):
+    answer = st.session_state.llm.invoke(prompt)
+    # remove the prompt text from the model output
+    answer = answer.replace(prompt, "").strip()
+    st.success("Answer generated!")
+    print("Answer generated!")
+    return answer
+
 def answer_question(repo_id, temperature, max_length, question):
     # initialize the Gemma model
     if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
@@ -91,12 +100,9 @@ def answer_question(repo_id, temperature, max_length, question):
         print("Local dataset filtering finished!")

         with st.spinner("Generating the answer..."):
-            answer = st.session_state.llm.invoke(prompt)
-            # remove the prompt text from the model output
-            answer = answer.replace(prompt, "").strip()
-            st.success("Answer generated!")
-            print("Answer generated!")
-            return {"prompt": prompt, "answer": answer}
+            pure_answer = get_answer(question)
+            answer = get_answer(prompt)
+            return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
     except Exception as e:
         st.error(f"Error during question answering: {e}")
         return {"prompt": "", "answer": "An error occurred during the answering process."}
@@ -117,6 +123,10 @@ def generate_answer(repo_id, temperature, max_length, question):
     result = answer_question(repo_id, float(temperature), int(max_length), question)
     print('prompt: ' + result["prompt"])
     print('answer: ' + result["answer"])
+    print('pure_answer: ' + result["pure_answer"])
+    st.write("Generated answer (without reference):")
+    st.write(result["pure_answer"])
+    st.divider()
     st.write("Reference text:")
     st.markdown(result["prompt"].replace('\n', '<br/>'))
     st.write("Generated answer:")