lwdragon commited on
Commit
ffe799a
·
1 Parent(s): 8509333

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from mindformers.trainer import Trainer
3
+
4
+ # 初始化trainer
5
+ trainer = Trainer(
6
+ task='token_classification',
7
+ model='tokcls_bert_base_chinese',
8
+ )
9
+
10
+ # examples
11
+ warm_input_data = ["结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。", "清华大学座落于首都北京"]
12
+
13
+ # warm_up
14
+ trainer.predict(
15
+ predict_checkpoint=
16
+ '/home/lwdragon/work/transformer-test/checkpoint_download.bak/tokcls/tokcls_bert_base_chinese_cluener.ckpt',
17
+ input_data=warm_input_data)
18
+
19
+
20
+ # 数据后处理,将数据转成gr.HighlightedText需要的数据
21
+ def post_procces(text, text_list):
22
+ res = []
23
+ cur_index = 0
24
+ for item in text_list:
25
+ res.append((text[cur_index:item["start"]], None))
26
+ res.append((text[item["start"]:(item["end"] + 1)],
27
+ " ".join([item["entity_group"],
28
+ str(item["score"])])))
29
+ cur_index = item["end"] + 1
30
+ res.append((text[cur_index:], None))
31
+ return res
32
+
33
+
34
+ # 预测
35
+ def token_classification(text):
36
+ res_list = trainer.predict(input_data=text)
37
+ res = post_procces(text, res_list[0])
38
+ print(res)
39
+ return res, res_list
40
+
41
+
42
+ # gradio接口
43
+ gr.Interface(
44
+ token_classification,
45
+ gr.Textbox(
46
+ label="Text",
47
+ info="Enter sentence here..xt",
48
+ lines=3,
49
+ value="结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。",
50
+ ),
51
+ # ["highlight", "json"],
52
+ [
53
+ gr.HighlightedText(
54
+ label="Token Classification",
55
+ combine_adjacent=True,
56
+ ),
57
+ gr.JSON()
58
+ ],
59
+ examples=[*warm_input_data],
60
+ ).launch()