# Spaces: Running (hosting-page status text, not program code)
# coding=utf-8
# author: xusong <[email protected]>
# time: 2022/8/23 17:08
"""
https://github.com/gradio-app/gradio/blob/299ba1bd1aed8040b3087c06c10fedf75901f91f/gradio/external.py#L484

interface = gr.Interface.load(
    "models/bert-base-uncased", api_key=None, alias="fill-mask"
)

## TODO:
1. json_output
2. Convert percentages to decimals
3.
"""
import gradio as gr | |
from info import article | |
from transformers import FillMaskPipeline | |
from transformers import BertTokenizer | |
from kplug.modeling_kplug import KplugForMaskedLM | |
# The tokenizer and the masked-LM model are loaded from the same identifier.
_KPLUG_CHECKPOINT = "eson/kplug-base-encoder"
tokenizer = BertTokenizer.from_pretrained(_KPLUG_CHECKPOINT)
model = KplugForMaskedLM.from_pretrained(_KPLUG_CHECKPOINT)
# fill mask
_fill_masker = None


def _get_fill_masker():
    """Return the shared FillMaskPipeline, creating it on first use."""
    global _fill_masker
    if _fill_masker is None:
        _fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
    return _fill_masker


def fill_mask(text):
    """Predict candidate tokens for the masked position(s) in ``text``.

    Args:
        text: Sentence to complete; the examples in this file mark the gap
            with ``[MASK]``.

    Returns:
        A dict mapping each candidate's ``token_str`` to its ``score``.
        When the pipeline returns one candidate list per mask (its documented
        result for inputs with several masks), keys are prefixed with the
        1-based mask position, e.g. ``"2:<token>"``, so candidates for
        different masks do not overwrite each other.
    """
    # Reuse one pipeline instead of rebuilding it on every request.
    outputs = _get_fill_masker()(text)

    # Single mask (or no candidates): a flat list of candidate dicts.
    if not outputs or isinstance(outputs[0], dict):
        return {cand["token_str"]: cand["score"] for cand in outputs}

    # Several masks: a list of candidate lists. Indexing those with a string
    # key, as the flat case does, would raise TypeError.
    scores = {}
    for position, candidates in enumerate(outputs, start=1):
        for cand in candidates:
            scores[f"{position}:{cand['token_str']}"] = cand["score"]
    return scores
# Sample sentences handed to the interface through its ``examples`` argument.
mlm_examples = [
    "这款连[MASK]裙真漂亮",
    "这是杨[MASK]同款包包,精选优质皮料制作",
    "美颜去痘洁面[MASK]",
]

# Widgets are created up front so the Interface call below stays flat.
_mlm_input = gr.Textbox(label="输入文本", value="这款连[MASK]裙真漂亮")
_mlm_output = gr.Label(label="填词")

# Demo wiring: text in, fill_mask() scores out, rendered by the Label widget.
mlm_iface = gr.Interface(
    fn=fill_mask,
    inputs=_mlm_input,
    outputs=_mlm_output,
    examples=mlm_examples,
    title="文本填词(Fill Mask)",
    description="基于KPLUG预训练语言模型",
    article=article,
)

if __name__ == "__main__":
    # Launch the UI only when executed as a script, not when imported.
    mlm_iface.launch()