Charles Chan commited on
Commit
86b4310
·
1 Parent(s): 908d31d
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -7,8 +7,9 @@ import random
7
 
8
  # 使用 進擊的巨人 数据集
9
  try:
 
10
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
11
- answer_list = [example["Answer"] for example in dataset["train"]]
12
 
13
  except Exception as e:
14
  st.error(f"读取数据集失败:{e}")
@@ -67,14 +68,18 @@ with col2:
67
  temperature = st.number_input("temperature", value=1.0)
68
  max_length = st.number_input("max_length", value=1024)
69
 
 
 
70
  col3, col4 = st.columns(2)
71
  with col3:
72
- if st.button("随机"):
73
  dataset_size = len(dataset["train"])
74
  random_index = random.randint(0, dataset_size - 1)
75
  # 读取随机问题
76
  random_question = dataset["train"][random_index]["Question"]
 
77
  origin_answer = dataset["train"][random_index]["Answer"]
 
78
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
79
  print('origin_answer: ' + origin_answer)
80
 
@@ -90,7 +95,7 @@ with col3:
90
 
91
  with col4:
92
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
93
- if st.button("提交"):
94
  if not question:
95
  st.warning("请输入问题!")
96
  else:
 
7
 
8
  # 使用 進擊的巨人 数据集
9
  try:
10
+ converter = pipeline("translation_zh_tw_zh_cn")
11
  dataset = load_dataset("rorubyy/attack_on_titan_wiki_chinese")
12
+ answer_list = [converter(example["Answer"])[0]["translation_text"] for example in dataset["train"]]
13
 
14
  except Exception as e:
15
  st.error(f"读取数据集失败:{e}")
 
68
  temperature = st.number_input("temperature", value=1.0)
69
  max_length = st.number_input("max_length", value=1024)
70
 
71
+ st.divider()
72
+
73
  col3, col4 = st.columns(2)
74
  with col3:
75
+ if st.button("使用原数据集中的随机问题"):
76
  dataset_size = len(dataset["train"])
77
  random_index = random.randint(0, dataset_size - 1)
78
  # 读取随机问题
79
  random_question = dataset["train"][random_index]["Question"]
80
+ random_question = converter(random_question)[0]["translation_text"]
81
  origin_answer = dataset["train"][random_index]["Answer"]
82
+ origin_answer = converter(origin_answer)[0]["translation_text"]
83
  print('[]' + str(random_index) + '/' + str(dataset_size) + ']random_question: ' + random_question)
84
  print('origin_answer: ' + origin_answer)
85
 
 
95
 
96
  with col4:
97
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
98
+ if st.button("提交输入的问题"):
99
  if not question:
100
  st.warning("请输入问题!")
101
  else: