Charles Chan commited on
Commit
1e06430
·
1 Parent(s): a6c563b
Files changed (1) hide show
  1. app.py +10 -3
app.py CHANGED
@@ -6,7 +6,7 @@ from langchain_community.vectorstores import FAISS
6
  from datasets import load_dataset
7
 
8
  # Streamlit 界面
9
- st.title("假知识库问答系统")
10
 
11
  # 使用 假知识 数据集
12
  if "data_list" not in st.session_state:
@@ -17,6 +17,9 @@ if not st.session_state.data_list:
17
  try:
18
  with st.spinner("正在读取数据库..."):
19
  dataset = load_dataset("zeerd/fake_knowledge")
 
 
 
20
  data_list = []
21
  answer_list = []
22
  for example in dataset["train"]:
@@ -66,6 +69,10 @@ def get_answer(prompt):
66
  # 问答函数
67
  def answer_question(repo_id, temperature, max_length, question):
68
  # 初始化 Gemma 模型
 
 
 
 
69
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
70
  try:
71
  with st.spinner("正在初始化 Gemma 模型..."):
@@ -86,9 +93,9 @@ def answer_question(repo_id, temperature, max_length, question):
86
  st.success("答案生成完毕(基于模型自身)!")
87
  print("答案生成完毕(基于模型自身)!")
88
  with st.spinner("正在筛选本地数据集..."):
89
- # question_embedding = st.session_state.embeddings.embed_query(question)
90
  # question_embedding_str = " ".join(map(str, question_embedding))
91
- docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question)
92
 
93
  context_list = []
94
  for doc, score in docs_and_scores:
 
6
  from datasets import load_dataset
7
 
8
  # Streamlit 界面
9
+ st.title("外挂知识库问答系统")
10
 
11
  # 使用 假知识 数据集
12
  if "data_list" not in st.session_state:
 
17
  try:
18
  with st.spinner("正在读取数据库..."):
19
  dataset = load_dataset("zeerd/fake_knowledge")
20
+ # 输出前五条数据
21
+ print(dataset["train"][:5])
22
+
23
  data_list = []
24
  answer_list = []
25
  for example in dataset["train"]:
 
69
  # 问答函数
70
  def answer_question(repo_id, temperature, max_length, question):
71
  # 初始化 Gemma 模型
72
+ print('repo_id: ' + repo_id)
73
+ print('temperature: ' + str(temperature))
74
+ print('max_length: ' + str(max_length))
75
+
76
  if repo_id != st.session_state.repo_id or temperature != st.session_state.temperature or max_length != st.session_state.max_length:
77
  try:
78
  with st.spinner("正在初始化 Gemma 模型..."):
 
93
  st.success("答案生成完毕(基于模型自身)!")
94
  print("答案生成完毕(基于模型自身)!")
95
  with st.spinner("正在筛选本地数据集..."):
96
+ question_embedding = st.session_state.embeddings.embed_query(question)
97
  # question_embedding_str = " ".join(map(str, question_embedding))
98
+ docs_and_scores = st.session_state.db.similarity_search_with_relevance_scores(question, 8, question_embedding)
99
 
100
  context_list = []
101
  for doc, score in docs_and_scores: