thak123's picture
Update app.py
b2aa5d6
raw
history blame
1.81 kB
import gradio as gr
from transformers import pipeline
from model import BERTBaseUncased
from tokenizer import tokenizer
import torch
from utils import label_full_decoder
import sys
import config
import dataset
import engine
from model import BERTBaseUncased
# Unused placeholder in this file; kept for backward compatibility.
MODEL = None
# Tweet tokenizer used by preprocess(), configured with handles and hashtags
# preserved and case/URL preservation turned off.
T = tokenizer.TweetTokenizer(
    preserve_handles=True,
    preserve_hashes=True,
    preserve_case=False,
    preserve_url=False,
)
device = config.device
# Load the fine-tuned weights once at start-up. map_location remaps the saved
# tensors onto the configured device (e.g. a GPU-saved checkpoint onto CPU).
# NOTE: torch.load unpickles the file, so MODEL_PATH must point at a trusted
# checkpoint.
model = BERTBaseUncased()
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)))
model.to(device)
# Put the module in inference mode (dropout off, batch-norm uses running
# stats) so predictions are deterministic even if engine.predict_fn does not
# switch modes itself.
model.eval()
def preprocess(text):
    """Tokenize *text* and collapse each run of mention tokens into one placeholder.

    The text is split with the module-level tokenizer ``T``. A token is
    treated as a mention when it contains ``"@"``. The first token of each
    consecutive run of mentions becomes ``"mention_0"``, and the remaining
    tokens of that run are dropped. Every other token is kept as-is.

    Both the raw token list and the processed token list are printed to
    stderr for debugging.

    Args:
        text: The raw input string.

    Returns:
        The processed tokens joined by single spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    kept = []
    after_mention = False  # whether the previous *raw* token contained "@"
    for tok in raw_tokens:
        is_mention = "@" in tok
        if not is_mention:
            kept.append(tok)
        elif not after_mention:
            # Only the first mention in a consecutive run is emitted.
            kept.append("mention_0")
        after_mention = is_mention

    print(kept, file=sys.stderr)
    return " ".join(kept)
def sentence_prediction(sentence):
    """Predict and decode the sentiment label for one input sentence.

    The sentence is normalised with ``preprocess``, wrapped in a one-item
    ``dataset.BERTDataset``, and run through ``engine.predict_fn`` using the
    module-level ``model`` and ``device``.

    Args:
        sentence: Raw text entered by the user.

    Returns:
        ``label_full_decoder`` applied to the first (and only) prediction.
    """
    sentence = preprocess(sentence)
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        # The true label is unknown at inference time; 0 appears to be a
        # placeholder required by the dataset interface (TODO confirm).
        target=[0],
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # One sentence yields one batch; loading it in the main process avoids
        # starting worker processes on every request.
        num_workers=0,
    )
    # Discard predict_fn's second return value. Destructuring it into ``[]``
    # would raise ValueError/TypeError whenever that value is non-empty.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs, file=sys.stderr)
    return label_full_decoder(outputs[0])
# Gradio UI: one text box in, the decoded sentiment label out.
interface = gr.Interface(
    fn=sentence_prediction,
    inputs="text",
    outputs=["text"],
    title="Sentiment Analysis",
    description="Get the positive/neutral/negative sentiment for the given input.",
)

# Start the server only when this file is run as a script (``python app.py``).
# Merely importing the module, e.g. a re-import by a multiprocessing "spawn"
# worker, must not try to launch a second server.
if __name__ == "__main__":
    interface.launch(inline=False)