|
import torch |
|
from utils import label_full_decoder |
|
import sys |
|
import dataset |
|
import engine |
|
from model import BERTBaseUncased |
|
|
|
import config |
|
from transformers import pipeline, AutoTokenizer, AutoModel |
|
import gradio as gr |
|
|
|
from ekphrasis.classes.preprocessor import TextPreProcessor |
|
from ekphrasis.classes.tokenizer import SocialTokenizer |
|
from ekphrasis.dicts.emoticons import emoticons |
|
|
|
# Device used for inference, taken from the project configuration.
device = config.device

# Build the network and restore the trained weights from config.MODEL_PATH,
# mapping the stored tensors onto the configured device while loading.
model = BERTBaseUncased()
model.load_state_dict(
    torch.load(config.MODEL_PATH, map_location=torch.device(device)),
    # strict=False: keys that do not match between the checkpoint and the
    # model are skipped instead of raising.
    # NOTE(review): this can hide a checkpoint/architecture mismatch --
    # confirm that the leniency is intentional.
    strict=False,
)
model.to(device)

# This script only serves predictions, so put the module in evaluation mode
# (disables dropout and similar train-only behaviour). The call is idempotent,
# so it is harmless if the prediction helper also switches modes itself.
model.eval()
|
|
|
|
|
|
|
# Shared ekphrasis pre-processing pipeline; preprocess() runs every incoming
# text through this object.
text_processor = TextPreProcessor(
    # Tokenise with the social-media tokenizer, lower-casing the text.
    tokenizer=SocialTokenizer(lowercase=True).tokenize,
    # Term categories handed to the processor for normalisation.
    normalize=["url", "email", "percent", "money", "phone", "user"],
    # Nothing is annotated.
    annotate={},
    fix_html=True,
    # Both the word segmenter and the spell corrector are configured with
    # the "twitter" setting.
    segmenter="twitter",
    corrector="twitter",
    # Optional expansion / correction steps are all switched off.
    unpack_hashtags=False,
    unpack_contractions=False,
    spell_correct_elong=False,
    # No replacement dictionaries (e.g. emoticons) are applied.
    dicts=[],
)
|
|
|
|
|
|
|
|
|
def preprocess(text):
    """Normalise one piece of raw text and collapse user mentions.

    The text is run through the module-level ``text_processor`` and the
    resulting tokens are joined back together with single spaces. Every token
    containing ``"@"`` is replaced by the placeholder ``"mention_0"``; a run
    of consecutive such tokens produces a single placeholder.

    The token list before and after the mention rewrite is echoed to
    ``sys.stderr`` as a debugging aid.

    Args:
        text: A single document (one string).

    Returns:
        The processed tokens joined by spaces.
    """
    # pre_process_doc handles ONE document and returns its tokens. The plural
    # pre_process_docs expects an iterable of documents and yields results
    # lazily; given a bare string it walks it character by character and
    # hands back a generator, which can be neither indexed nor joined.
    tokens = text_processor.pre_process_doc(text)
    print(tokens, file=sys.stderr)

    ptokens = []
    previous_was_mention = False
    for token in tokens:
        is_mention = "@" in token
        if not is_mention:
            ptokens.append(token)
        elif not previous_was_mention:
            # Only the first token of a run of mentions emits a placeholder.
            ptokens.append("mention_0")
        previous_was_mention = is_mention

    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
|
|
|
|
|
def sentence_prediction(sentence):
    """Classify a single sentence; used as the Gradio callback.

    The sentence is cleaned with ``preprocess``, wrapped in a one-item
    ``dataset.BERTDataset`` and pushed through ``engine.predict_fn`` with the
    module-level ``model`` and ``device``. The value finally returned (and
    printed to stdout) is whatever ``classifier`` yields for the cleaned
    sentence.

    Args:
        sentence: Raw input text.

    Returns:
        The result of ``classifier(sentence)`` on the pre-processed text.
    """
    sentence = preprocess(sentence)

    # target=[0] is presumably a dummy label required by the dataset class;
    # verify against dataset.BERTDataset.
    test_dataset = dataset.BERTDataset(review=[sentence], target=[0])

    # num_workers=0 loads the single item in the calling process. Starting
    # worker processes for every request is pure overhead for one sentence,
    # and on platforms that start workers by spawning, each worker re-imports
    # this script, whose top level loads the model and launches the UI.
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=0,
    )

    # Only the first element of the returned pair is kept. The previous
    # ``outputs, [] = ...`` form raised ValueError whenever the second
    # element was not empty.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)

    # NOTE(review): ``classifier`` is not defined in the visible part of this
    # file, and its result replaces the model outputs computed just above.
    # Confirm which of the two predictions is meant to be returned and remove
    # the other; if ``classifier`` does not exist, this line raises NameError.
    outputs = classifier(sentence)

    print(outputs)
    return outputs
|
|
|
|
|
|
|
|
|
# Gradio front end: a free-text input wired to sentence_prediction, with the
# returned value shown through a "label" output component.
demo = gr.Interface(fn=sentence_prediction, inputs="text", outputs="label")

# Start serving the interface as soon as the script is executed.
demo.launch()
|
|
|
|