# Scrape residue from the hosting page banner: "Spaces: Running"
# coding=utf-8 | |
# author: xusong <[email protected]> | |
# time: 2022/8/23 17:08 | |
import time | |
import torch | |
import gradio as gr | |
from info import article | |
from transformers import FillMaskPipeline | |
from transformers import BertTokenizer | |
from kplug.modeling_kplug import KplugForMaskedLM | |
from pycorrector.bert.bert_corrector import BertCorrector | |
from pycorrector import config | |
from loguru import logger | |
# Device index handed to the transformers pipeline: GPU 0 when CUDA is
# available, otherwise -1 (run on CPU).
device_id = -1
if torch.cuda.is_available():
    device_id = 0

# Extra page CSS passed to gr.Interface; hides elements with the
# `category-legend` class (presumably the HighlightedText legend).
css = """
.category-legend {display: none !important}
"""
class KplugCorrector(BertCorrector):
    """pycorrector ``BertCorrector`` whose fill-mask model is a KPLUG encoder.

    Correction logic (``bert_correct``) is inherited unchanged from
    ``BertCorrector``; this subclass only replaces what it builds in
    ``__init__``: ``self.model`` (a ``FillMaskPipeline`` over KPLUG) and
    ``self.mask`` (that tokenizer's mask token).
    """

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id,
                 model_name_or_path="eson/kplug-base-encoder"):
        """Load the KPLUG tokenizer/model and wrap them in a fill-mask pipeline.

        Args:
            bert_model_dir: Kept only for signature compatibility with
                ``BertCorrector``; it is NOT used to load any model.
            device: Pipeline device index; -1 runs on CPU, 0 selects the first
                CUDA GPU (see the module-level ``device_id``).
            model_name_or_path: Hugging Face hub id or local directory of the
                KPLUG checkpoint used for both tokenizer and model.
        """
        # super(BertCorrector, self) starts the MRO lookup *after*
        # BertCorrector, so BertCorrector.__init__ is skipped on purpose and
        # only the grandparent initializer runs (presumably to avoid loading
        # the default BERT model from bert_model_dir — confirm).
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        t1 = time.time()
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        model = KplugForMaskedLM.from_pretrained(model_name_or_path)
        self.model = FillMaskPipeline(model=model, tokenizer=tokenizer, device=device)
        if self.model:
            # Presumably consumed by the inherited bert_correct() to build
            # masked queries — confirm against pycorrector.
            self.mask = self.model.tokenizer.mask_token
            # Report the checkpoint that was actually loaded, not bert_model_dir.
            logger.debug('Loaded kplug model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))
# Example inputs: shown as clickable examples in the UI and iterated by test().
error_sentences = [
    "少先队员因该为老人让坐",
    "机七学习是人工智能领遇最能体现智能的一个分知",
    "今天心情很好",
]

# Single module-level corrector shared by correct() and test(); constructing it
# loads the KPLUG tokenizer and model once, at import time.
corrector = KplugCorrector()
def mock_data():
    """Return a canned ``(corrected_sentence, errors)`` pair for offline UI checks.

    Lets the Gradio output be exercised without running the model. Each error
    tuple is laid out as ``(original_char, corrected_char, start, end)``, the
    shape ``correct`` reads via ``err[1]``, ``err[2]`` and ``err[3]``.
    """
    errors = [
        ("七", "器", 1, 2),
        ("遇", "域", 10, 11),
    ]
    return "机器学习是人工智能领域最能体现智能的一个分知", errors
def correct(sent):
    """Correct ``sent`` and shape the result for the two Gradio output components.

    Returns:
        A pair ``(highlighted, errs)``:

        * ``highlighted`` is ``{"text": corrected_sent, "entities": [...]}``,
          the format required by ``gr.HighlightedText``
          (see https://www.gradio.app/docs/highlightedtext). Each entity marks
          the corrected character over ``[start, end)`` of the corrected text.
        * ``errs`` is the raw error list from ``corrector.bert_correct``,
          rendered by the ``gr.JSON`` output.

    Each error is indexed as ``err[1]`` (corrected char), ``err[2]`` (start),
    and ``err[3]`` (end). ``err[0]`` is unused here; per ``mock_data`` it is
    the original character.
    """
    corrected_sent, errs = corrector.bert_correct(sent)
    # For offline UI testing, use mock_data() instead of the model call above.
    print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
    # "纠错" ("correction") is the entity label. The score is a fixed
    # placeholder because the error tuples carry no confidence value.
    entities = [
        {"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]}
        for err in errs
    ]
    return {"text": corrected_sent, "entities": entities}, errs
def test():
    """Smoke test: run the corrector on every example sentence and print each result."""
    for example in error_sentences:
        fixed, details = corrector.bert_correct(example)
        print(f"original sentence:{example} => {fixed}, err:{details}")
# Input: a single text box ("输入文本" = "input text") prefilled with an example.
_input_box = gr.Textbox(label="输入文本", value="少先队员因该为老人让坐")

# Outputs, in the order correct() returns its values: the highlighted corrected
# text ("文本纠错" = "text correction"), then the raw error list as JSON.
# NOTE(review): show_legend=True, yet the module-level css hides
# `.category-legend`; confirm which behavior is intended.
_outputs = [
    gr.HighlightedText(label="文本纠错", show_legend=True),
    gr.JSON(),
]

corr_iface = gr.Interface(
    fn=correct,
    inputs=_input_box,
    outputs=_outputs,
    examples=error_sentences,
    title="文本纠错(Corrector)",
    description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
    article=article,
    css=css,
)

if __name__ == "__main__":
    # Swap in test() here to smoke-test the corrector without launching the UI.
    corr_iface.launch()