# Scraped page header (not part of the program):
# thak123's picture · "Update app.py" · commit a5b267d · raw · history blame · 1.89 kB
import gradio as gr
from transformers import pipeline
from model import BERTBaseUncased
from tokenizer import tokenizer
import torch
from utils import label_full_decoder
import sys
import config
import dataset
import engine
from model import BERTBaseUncased
# Module-level slot for a model instance; it starts out unset.
MODEL = None
# Device setting taken from the project config module.
DEVICE = config.device
def get_sentiment(input_text):
    """Format the first prediction for ``input_text`` as two display strings.

    Passes ``input_text`` to the module-level ``sentiment`` callable and reads
    the ``'label'`` and ``'score'`` keys of the first item it returns.

    NOTE(review): ``sentiment`` is not defined in the visible part of this
    file (only ``pipeline`` is imported) -- presumably it is meant to be a
    ``transformers`` pipeline; confirm it exists before this is called.

    Returns:
        A ``(label_text, score_text)`` tuple of strings.
    """
    top = sentiment(input_text)[0]
    label_text = f"result: {top['label']}"
    score_text = f"score: {top['score']}"
    return label_text, score_text
def preprocess(text):
    """Tokenize ``text`` and replace runs of mention tokens with a placeholder.

    Any token containing ``"@"`` counts as a mention. The first mention of a
    consecutive run becomes ``"mention_0"``; the following mentions in that
    run are dropped. All other tokens are kept unchanged. Both the raw and
    the processed token lists are written to stderr for debugging.

    NOTE(review): ``T`` is not defined in the visible part of this file (only
    ``tokenizer`` is imported) -- presumably a tokenizer instance built from
    that module; confirm it is created at module level.

    Returns:
        The processed tokens joined by single spaces.
    """
    tokens = T.tokenize(text)
    print(tokens, file=sys.stderr)

    ptokens = []
    previous_was_mention = False
    for token in tokens:
        is_mention = "@" in token
        if not is_mention:
            ptokens.append(token)
        elif not previous_was_mention:
            # Only the first mention of a consecutive run gets a placeholder.
            ptokens.append("mention_0")
        previous_was_mention = is_mention

    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
def _load_model():
    """Return the classifier, building and loading it on first use.

    The instance is kept in the module-level ``MODEL`` so the checkpoint at
    ``config.MODEL_PATH`` is read once instead of on every request.
    """
    global MODEL
    if MODEL is None:
        device = config.device
        model = BERTBaseUncased()
        model.load_state_dict(
            torch.load(config.MODEL_PATH, map_location=torch.device(device))
        )
        model.to(device)
        MODEL = model
    return MODEL


def sentence_prediction(sentence):
    """Run the model on a single sentence and return its first prediction.

    The sentence is passed through ``preprocess``, wrapped in a one-item
    ``dataset.BERTDataset`` and fed to ``engine.predict_fn`` through a
    ``DataLoader``.

    Args:
        sentence: Raw input text.

    Returns:
        ``outputs[0]``, the first element of the predictions returned by
        ``engine.predict_fn``. Its exact type depends on that function.
    """
    sentence = preprocess(sentence)

    # The dataset is given one target per review; 0 is a placeholder here
    # (presumably ignored at prediction time -- TODO confirm).
    test_dataset = dataset.BERTDataset(review=[sentence], target=[0])
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3,
    )

    # The second return value is discarded. Unpacking it into ``[]`` (as the
    # previous code did) raises ValueError whenever it is non-empty.
    outputs, _ = engine.predict_fn(
        test_data_loader, _load_model(), config.device
    )
    print(outputs)
    return outputs[0]
# Gradio UI: one text input routed to ``sentence_prediction``.
# NOTE(review): two text outputs are declared, while ``sentence_prediction``
# returns the single value ``outputs[0]`` -- confirm it yields two items.
interface = gr.Interface(
    fn=sentence_prediction,
    inputs="text",
    outputs=["text", "text"],
    title="Sentiment Analysis",
    description="Get the positive/negative sentiment for the given input.",
)

# Start the web app as soon as this module is executed.
interface.launch(inline=False)