Charles Chan commited on
Commit
c2aa18b
·
1 Parent(s): 131ded0
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -8,26 +8,29 @@ from opencc import OpenCC
8
 
9
  # 使用 進擊的巨人 数据集
10
  # 原数据集是是繁体中文,为了调试方便,将其转换成简体中文之后使用
11
- if "dataset_loaded" not in st.session_state:
12
- st.session_state.dataset_loaded = False
13
  st.session_state.data_list = []
14
  st.session_state.answer_list = []
15
- if not st.session_state.dataset_loaded:
 
16
  try:
17
  with st.spinner("正在读取数据库..."):
18
- converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
19
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
 
 
20
  for example in dataset["train"]:
21
- converted_answer = converter.convert(example["Answer"])
22
- converted_question = converter.convert(example["Question"])
23
- st.session_state.answer_list.append(converted_answer)
24
- st.session_state.data_list.append({"Question": converted_question, "Answer": converted_answer})
 
 
25
  st.success("数据库读取完成!")
26
  print("数据库读取完成!")
27
  except Exception as e:
28
  st.error(f"读取数据集失败:{e}")
29
  st.stop()
30
- st.session_state.dataset_loaded = True
31
 
32
  # 构建向量数据库 (如果需要,仅构建一次)
33
  if "vector_created" not in st.session_state:
@@ -117,9 +120,9 @@ with col3:
117
  random_index = random.randint(0, dataset_size - 1)
118
  # 读取随机问题
119
  random_question = st.session_state.data_list[random_index]["Question"]
120
- random_question = converter.convert(random_question)
121
  origin_answer = st.session_state.data_list[random_index]["Answer"]
122
- origin_answer = converter.convert(origin_answer)
123
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
124
  print('origin_answer: ' + origin_answer)
125
 
 
8
 
9
  # 使用 進擊的巨人 数据集
10
  # 原数据集是是繁体中文,为了调试方便,将其转换成简体中文之后使用
11
+ if "data_list" not in st.session_state:
 
12
  st.session_state.data_list = []
13
  st.session_state.answer_list = []
14
+
15
+ if not st.session_state.data_list:
16
  try:
17
  with st.spinner("正在读取数据库..."):
18
+ st.session_state.converter = OpenCC('tw2s') # 'tw2s.json' 表示繁体中文到简体中文的转换
19
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
20
+ data_list = []
21
+ answer_list = []
22
  for example in dataset["train"]:
23
+ converted_answer = st.session_state.converter.convert(example["Answer"])
24
+ converted_question = st.session_state.converter.convert(example["Question"])
25
+ answer_list.append(converted_answer)
26
+ data_list.append({"Question": converted_question, "Answer": converted_answer})
27
+ st.session_state.answer_list = answer_list
28
+ st.session_state.data_list = data_list
29
  st.success("数据库读取完成!")
30
  print("数据库读取完成!")
31
  except Exception as e:
32
  st.error(f"读取数据集失败:{e}")
33
  st.stop()
 
34
 
35
  # 构建向量数据库 (如果需要,仅构建一次)
36
  if "vector_created" not in st.session_state:
 
120
  random_index = random.randint(0, dataset_size - 1)
121
  # 读取随机问题
122
  random_question = st.session_state.data_list[random_index]["Question"]
123
+ random_question = st.session_state.converter.convert(random_question)
124
  origin_answer = st.session_state.data_list[random_index]["Answer"]
125
+ origin_answer = st.session_state.converter.convert(origin_answer)
126
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
127
  print('origin_answer: ' + origin_answer)
128