|
import torch |
|
from utils import label_full_decoder |
|
import sys |
|
import dataset |
|
import engine |
|
from model import BERTBaseUncased |
|
|
|
import config |
|
from transformers import pipeline, AutoTokenizer, AutoModel |
|
import gradio as gr |
|
|
|
from ekphrasis.classes.preprocessor import TextPreProcessor |
|
from ekphrasis.classes.tokenizer import SocialTokenizer |
|
from ekphrasis.dicts.emoticons import emoticons |
|
|
|
# Device string/object comes from the project config and is reused for both
# checkpoint loading and inference.
device = config.device

# Build the network and restore the fine-tuned weights once at import time so
# every request reuses the same model instance.
model = BERTBaseUncased()
model.load_state_dict(
    # map_location remaps the checkpoint's tensors onto the configured device.
    torch.load(config.MODEL_PATH, map_location=torch.device(device)),
    # NOTE(review): strict=False tolerates missing/unexpected keys in the
    # checkpoint instead of raising — confirm this leniency is intentional.
    strict=False,
)
model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Bound tokenize method of a lowercasing SocialTokenizer; used by preprocess().
social_tokenizer = SocialTokenizer(lowercase=True).tokenize
|
|
|
def preprocess(text):
    """Tokenize *text* and anonymise user mentions.

    The text is split with the module-level ``social_tokenizer``. Every token
    containing ``"@"`` is replaced by the placeholder ``"mention_0"``, and a
    run of consecutive such tokens collapses into a single placeholder. All
    other tokens are kept unchanged.

    Both the raw and the processed token lists are printed to stderr.

    Args:
        text: Input string to normalise.

    Returns:
        The processed tokens joined by single spaces.
    """
    raw_tokens = social_tokenizer(text)
    print(raw_tokens, file=sys.stderr)

    cleaned = []
    previous_was_mention = False
    for tok in raw_tokens:
        is_mention = "@" in tok
        if not is_mention:
            cleaned.append(tok)
        elif not previous_was_mention:
            # Only the first token of a run of mentions emits the placeholder.
            cleaned.append("mention_0")
        previous_was_mention = is_mention

    print(cleaned, file=sys.stderr)
    return " ".join(cleaned)
|
|
|
|
|
def predict_sentiment(sentence):
    """Run the module-level BERT model on one sentence and return its output.

    The sentence is normalised with ``preprocess``, wrapped in a one-item
    ``dataset.BERTDataset`` and passed through ``engine.predict_fn`` together
    with the module-level ``model`` and ``device``.

    Args:
        sentence: Raw input text.

    Returns:
        The first element of the pair returned by ``engine.predict_fn``.
    """
    sentence = preprocess(sentence)

    test_dataset = dataset.BERTDataset(
        review=[sentence],
        # The dataset requires a target; 0 is a placeholder value here.
        target=[0],
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # A single sample is loaded per request, so load it in the main
        # process rather than spawning worker processes on every call.
        num_workers=0,
    )

    # predict_fn returns a pair; only the first element is used here.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)

    # The model output is returned directly. It must not be replaced by a call
    # to `classifier`, which is not defined anywhere in this module and would
    # raise NameError on every request.
    # NOTE(review): if predict_fn yields class indices, they presumably need
    # decoding (e.g. via utils.label_full_decoder) before the Gradio 'label'
    # output can render them — TODO confirm against engine/utils.
    print(outputs)
    return outputs
|
|
|
|
|
|
|
|
|
# Web UI: a single text box in, a label component out, backed by
# predict_sentiment.
interface = gr.Interface(
    fn=predict_sentiment,
    inputs='text',
    outputs=['label'],
    title='Latvian Twitter Sentiment Analysis',
    examples=[
        "Es mīlu Tevi",
        "Es ienīstu kafiju",
    ],
    description='Get the positive/neutral/negative sentiment for the given input.',
)

# Start the Gradio server when the script is executed.
interface.launch(inline=False)
|
|
|
|