# Source: commit ca91c16 ("Update app.py") by thak123
import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
import config
from transformers import pipeline, AutoTokenizer, AutoModel
import gradio as gr
from ekphrasis.classes.preprocessor import TextPreProcessor
from ekphrasis.classes.tokenizer import SocialTokenizer
from ekphrasis.dicts.emoticons import emoticons
# Load the fine-tuned sentiment model once at import time, on the device
# selected in the project config.
device = config.device
model = BERTBaseUncased()
# strict=False lets the app start even when the checkpoint's keys do not match
# the model's exactly. It can also silently leave layers at their initial
# values, so surface any mismatch on stderr instead of hiding it.
_load_result = model.load_state_dict(
    torch.load(config.MODEL_PATH, map_location=torch.device(device)),
    strict=False,
)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        f"Checkpoint/model key mismatch: "
        f"missing={_load_result.missing_keys} "
        f"unexpected={_load_result.unexpected_keys}",
        file=sys.stderr,
    )
model.to(device)
# Inference only: put the model in eval mode (disables dropout). engine.predict_fn
# may already do this; eval() is idempotent, so calling it here is harmless.
model.eval()
# T = tokenizer.TweetTokenizer(
# preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
# text_processor = TextPreProcessor(
# # terms that will be normalized
# normalize=['url', 'email', 'percent', 'money', 'phone', 'user'],
# # terms that will be annotated
# annotate={},
# fix_html=True, # fix HTML tokens
# # corpus from which the word statistics are going to be used
# # for word segmentation
# segmenter="twitter",
# # corpus from which the word statistics are going to be used
# # for spell correction
# corrector="twitter",
# unpack_hashtags=False, # perform word segmentation on hashtags
# unpack_contractions=False, # Unpack contractions (can't -> can not)
# spell_correct_elong=False, # spell correction for elongated words
# # select a tokenizer. You can use SocialTokenizer, or pass your own
# # the tokenizer, should take as input a string and return a list of tokens
# tokenizer=SocialTokenizer(lowercase=True).tokenize,
# # list of dictionaries, for replacing tokens extracted from the text,
# # with other expressions. You can pass more than one dictionaries.
# dicts=[]
# )
# ekphrasis social-media tokenizer, configured to lowercase its tokens.
social_tokenizer = SocialTokenizer(lowercase=True).tokenize


def preprocess(text):
    """Tokenize ``text`` and anonymise user mentions.

    Every token that contains ``"@"`` is treated as a mention and replaced by
    the placeholder ``"mention_0"``. A run of consecutive mention tokens
    collapses into a single placeholder. The raw and the processed token lists
    are echoed to stderr for debugging.

    Args:
        text: Raw input text.

    Returns:
        The processed tokens joined with single spaces.
    """
    tokens = social_tokenizer(text)
    print(tokens, file=sys.stderr)

    processed = []
    previous_was_mention = False
    for token in tokens:
        is_mention = "@" in token
        if not is_mention:
            processed.append(token)
        elif not previous_was_mention:
            # First mention of a run: emit one placeholder for the whole run.
            processed.append("mention_0")
        previous_was_mention = is_mention

    print(processed, file=sys.stderr)
    return " ".join(processed)
def predict_sentiment(sentence=""):
    """Classify the sentiment of a single piece of text.

    The text is normalised with ``preprocess``, then wrapped in a one-item
    ``dataset.BERTDataset`` with a placeholder target of 0, because no label
    exists at inference time. That dataset is run through
    ``engine.predict_fn``, and the first prediction is decoded with
    ``label_full_decoder``.

    Args:
        sentence: Raw input text, e.g. a tweet.

    Returns:
        Whatever ``label_full_decoder`` produces for the first prediction;
        it is displayed by the Gradio ``label`` output component.
    """
    sentence = preprocess(sentence)
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],
    )
    # A single-sample request gains nothing from worker processes. Spawning
    # them on every call adds latency, and under the "spawn" start method each
    # worker re-imports this module, which has unguarded top-level code. Load
    # in the main process instead.
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=0,
    )
    # Only the predictions are used. Bind the second return value to `_`:
    # unpacking it into `[]` would raise ValueError whenever it is non-empty.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs, file=sys.stderr)
    return label_full_decoder(outputs[0])
# Gradio UI: one free-text input mapped through predict_sentiment to a single
# "label" output. The examples are Latvian ("I love you", "I hate coffee").
interface = gr.Interface(
    fn=predict_sentiment,
    inputs="text",
    outputs=["label"],
    title="Latvian Twitter Sentiment Analysis",
    examples=["Es mīlu Tevi", "Es ienīstu kafiju"],
    description="Get the positive/neutral/negative sentiment for the given input.",
)

# inline=False: serve the app normally rather than embedding it inline
# (e.g. in a notebook iframe).
interface.launch(inline=False)