Charles Chan commited on
Commit
8db8541
·
1 Parent(s): b9bde1a
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -88,11 +88,13 @@ def answer_question(repo_id, temperature, max_length, question):
88
  with st.spinner("正在筛选本地数据集..."):
89
  question_embedding = st.session_state.embeddings.embed_query(question)
90
  question_embedding_str = " ".join(map(str, question_embedding))
91
- # print('question_embedding: ' + question_embedding_str)
92
  docs_and_scores = st.session_state.db.similarity_search_with_score(question_embedding_str)
93
 
94
- context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
95
- print('context: ' + context)
 
 
 
96
 
97
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
98
  print('prompt: ' + prompt)
@@ -100,10 +102,10 @@ def answer_question(repo_id, temperature, max_length, question):
100
  st.success("本地数据集筛选完成!")
101
  print("本地数据集筛选完成!")
102
 
103
- with st.spinner("正在生成答案..."):
104
  answer = get_answer(prompt)
105
- st.success("答案生成完毕!")
106
- print("答案生成完毕!")
107
  return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
108
  except Exception as e:
109
  st.error(f"问答过程出错:{e}")
@@ -149,7 +151,7 @@ with col3:
149
  generate_answer(gemma, float(temperature), int(max_length), random_question)
150
 
151
  with col4:
152
- question = st.text_area("请输入问题", "谁是潜水员?")
153
  if st.button("提交输入的问题"):
154
  if not question:
155
  st.warning("请输入问题!")
 
88
  with st.spinner("正在筛选本地数据集..."):
89
  question_embedding = st.session_state.embeddings.embed_query(question)
90
  question_embedding_str = " ".join(map(str, question_embedding))
 
91
  docs_and_scores = st.session_state.db.similarity_search_with_score(question_embedding_str)
92
 
93
+ context_list = []
94
+ for doc, score in docs_and_scores:
95
+ print(str(score) + ' : ' + doc.page_content)
96
+ context_list.append(doc.page_content)
97
+ context = "\n".join(context_list)
98
 
99
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
100
  print('prompt: ' + prompt)
 
102
  st.success("本地数据集筛选完成!")
103
  print("本地数据集筛选完成!")
104
 
105
+ with st.spinner("正在生成答案(基于本地数据集)..."):
106
  answer = get_answer(prompt)
107
+ st.success("答案生成完毕(基于本地数据集)!")
108
+ print("答案生成完毕(基于本地数据集)!")
109
  return {"prompt": prompt, "answer": answer, "pure_answer": pure_answer}
110
  except Exception as e:
111
  st.error(f"问答过程出错:{e}")
 
151
  generate_answer(gemma, float(temperature), int(max_length), random_question)
152
 
153
  with col4:
154
+ question = st.text_area("请输入问题", "太阳为什么是绿色的?")
155
  if st.button("提交输入的问题"):
156
  if not question:
157
  st.warning("请输入问题!")