Charles Chan committed on
Commit de611e2 · 1 Parent(s): 1e21aa9
Files changed (1)
  1. app.py +16 -5
app.py CHANGED
@@ -10,13 +10,20 @@ from opencc import OpenCC
 # The original dataset is in Traditional Chinese; for easier debugging it is converted to Simplified Chinese before use
 if "dataset_loaded" not in st.session_state:
     st.session_state.dataset_loaded = False
+    st.session_state.data_list = []
+    st.session_state.answer_list = []
 if not st.session_state.dataset_loaded:
     try:
         with st.spinner("正在读取数据库..."):
             converter = OpenCC('tw2s')  # 'tw2s.json' means Traditional-to-Simplified Chinese conversion
             dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
-            answer_list = [converter.convert(example["Answer"]) for example in dataset["train"]]
+            for example in dataset["train"]:
+                converted_answer = converter.convert(example["Answer"])
+                converted_question = converter.convert(example["Question"])
+                st.session_state.answer_list.append(converted_answer)
+                st.session_state.data_list.append({"Question": converted_question, "Answer": converted_answer})
             st.success("数据库读取完成!")
+            print("数据库读取完成!")
     except Exception as e:
         st.error(f"读取数据集失败:{e}")
         st.stop()
@@ -29,8 +36,9 @@ if not st.session_state.vector_created:
     try:
         with st.spinner("正在构建向量数据库..."):
             embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
-            db = FAISS.from_texts(answer_list, embeddings)
+            db = FAISS.from_texts(st.session_state.answer_list, embeddings)
             st.success("向量数据库构建完成!")
+            print("向量数据库构建完成!")
     except Exception as e:
         st.error(f"向量数据库构建失败:{e}")
         st.stop()
@@ -50,6 +58,7 @@ def answer_question(repo_id, temperature, max_length, question):
     with st.spinner("正在初始化 Gemma 模型..."):
         llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
         st.success("Gemma 模型初始化完成!")
+        print("Gemma 模型初始化完成!")
     st.session_state.repo_id = repo_id
     st.session_state.temperature = temperature
     st.session_state.max_length = max_length
@@ -72,12 +81,14 @@ def answer_question(repo_id, temperature, max_length, question):
         print('prompt: ' + prompt)

         st.success("本地数据集筛选完成!")
+        print("本地数据集筛选完成!")

         with st.spinner("正在生成答案..."):
             answer = llm.invoke(prompt)
             # Strip the prompt text from the model output
             answer = answer.replace(prompt, "").strip()
             st.success("答案已经生成!")
+            print("答案已经生成!")
         return {"prompt": prompt, "answer": answer}
     except Exception as e:
         st.error(f"问答过程出错:{e}")
@@ -98,12 +109,12 @@ st.divider()
 col3, col4 = st.columns(2)
 with col3:
     if st.button("使用原数据集中的随机问题"):
-        dataset_size = len(dataset["train"])
+        dataset_size = len(st.session_state.data_list)
         random_index = random.randint(0, dataset_size - 1)
         # Read the random question and its reference answer
-        random_question = dataset["train"][random_index]["Question"]
+        random_question = st.session_state.data_list[random_index]["Question"]
         random_question = converter.convert(random_question)
-        origin_answer = dataset["train"][random_index]["Answer"]
+        origin_answer = st.session_state.data_list[random_index]["Answer"]
         origin_answer = converter.convert(origin_answer)
         print('[' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
         print('origin_answer: ' + origin_answer)
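Note on the first hunk: the one-shot answer_list comprehension becomes a loop that caches converted questions and answers in st.session_state, so Streamlit's script reruns do not reload and reconvert the dataset. A minimal standalone sketch of that load-once pattern follows; the dataset name and field names come from the diff, while setting dataset_loaded = True at the end is an assumption about code outside this hunk:

    import streamlit as st
    from datasets import load_dataset
    from opencc import OpenCC

    if "dataset_loaded" not in st.session_state:
        # First run of the script: create the cache slots once.
        st.session_state.dataset_loaded = False
        st.session_state.data_list = []
        st.session_state.answer_list = []

    if not st.session_state.dataset_loaded:
        converter = OpenCC('tw2s')  # Traditional (Taiwan) -> Simplified Chinese
        dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
        for example in dataset["train"]:
            answer = converter.convert(example["Answer"])
            question = converter.convert(example["Question"])
            st.session_state.answer_list.append(answer)
            st.session_state.data_list.append({"Question": question, "Answer": answer})
        st.session_state.dataset_loaded = True  # assumed to happen elsewhere in app.py, outside this hunk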
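Note on the second hunk: FAISS.from_texts now reads the cached st.session_state.answer_list. A rough sketch of how such an index is built and queried with these LangChain classes; import paths differ between LangChain releases, and the sample texts and query are placeholders rather than dataset content:

    from langchain.embeddings import SentenceTransformerEmbeddings  # langchain_community.embeddings in newer releases
    from langchain.vectorstores import FAISS                        # langchain_community.vectorstores in newer releases

    # Placeholder texts standing in for st.session_state.answer_list.
    answer_list = [
        "placeholder answer one",
        "placeholder answer two",
    ]

    embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    db = FAISS.from_texts(answer_list, embeddings)

    # Each hit is a Document whose page_content is one of the original answers.
    hits = db.similarity_search("placeholder question", k=2)
    for doc in hits:
        print(doc.page_content)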
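Note on the answer_question hunks: the model is created through HuggingFaceHub and the generated text then has the prompt stripped from it, since the Hub text-generation endpoint frequently echoes the prompt before the completion. A hedged sketch of that call path; the repo_id value and the prompt template are illustrative only, and a HUGGINGFACEHUB_API_TOKEN is assumed to be set in the environment:

    from langchain.llms import HuggingFaceHub  # langchain_community.llms in newer releases

    # Assumes HUGGINGFACEHUB_API_TOKEN is exported in the environment.
    llm = HuggingFaceHub(
        repo_id="google/gemma-7b-it",  # illustrative value; app.py receives repo_id as a function argument
        model_kwargs={"temperature": 0.7, "max_length": 512},
    )

    # The real prompt is assembled elsewhere in app.py; this template is only a stand-in.
    context = "retrieved answer text from the FAISS search"
    question = "a user question"
    prompt = f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"

    answer = llm.invoke(prompt)
    # Hub text-generation models often return prompt + completion, hence the strip step in the diff.
    answer = answer.replace(prompt, "").strip()
    print(answer)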