Spaces:
Runtime error
Runtime error
import gradio as gr | |
from mindformers.trainer import Trainer | |
# 初始化trainer | |
trainer = Trainer( | |
task='token_classification', | |
model='tokcls_bert_base_chinese', | |
) | |
# examples | |
warm_input_data = ["结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。", "清华大学座落于首都北京"] | |
# warm_up | |
trainer.predict( | |
predict_checkpoint= | |
'./tokcls_bert_base_chinese_cluener.ckpt', | |
input_data=warm_input_data) | |
# 数据后处理,将数据转成gr.HighlightedText需要的数据 | |
def post_procces(text, text_list): | |
res = [] | |
cur_index = 0 | |
for item in text_list: | |
res.append((text[cur_index:item["start"]], None)) | |
res.append((text[item["start"]:(item["end"] + 1)], | |
" ".join([item["entity_group"], | |
str(item["score"])]))) | |
cur_index = item["end"] + 1 | |
res.append((text[cur_index:], None)) | |
return res | |
# 预测 | |
def token_classification(text): | |
res_list = trainer.predict(input_data=text) | |
res = post_procces(text, res_list[0]) | |
print(res) | |
return res, res_list | |
# gradio接口 | |
gr.Interface( | |
token_classification, | |
gr.Textbox( | |
label="Text", | |
info="Enter sentence here..xt", | |
lines=3, | |
value="结果上周六他们主场0:3惨败给了中游球队瓦拉多利德,近7个多月以来西甲首次输球。", | |
), | |
# ["highlight", "json"], | |
[ | |
gr.HighlightedText( | |
label="Token Classification", | |
combine_adjacent=True, | |
), | |
gr.JSON() | |
], | |
examples=[*warm_input_data], | |
).launch() | |