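# Scraped from the Hugging Face file viewer (page chrome, not program text):
#   author: thak123 | commit: ce26deb ("Update app.py") | size: 1.79 kB
#   viewer links: raw / history / blame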
import functools
import sys

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModel

import config
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
from utils import label_full_decoder
# NOTE: tweet preprocessing is currently disabled. The helper below is kept
# for reference only; it replaces each run of consecutive "@" tokens with a
# single "mention_0" placeholder.
# T = tokenizer.TweetTokenizer(
#     preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
# def preprocess(text):
#     tokens = T.tokenize(text)
#     print(tokens, file=sys.stderr)
#     ptokens = []
#     for index, token in enumerate(tokens):
#         if "@" in token:
#             if index > 0:
#                 # check if previous token was mention
#                 if "@" in tokens[index-1]:
#                     pass
#                 else:
#                     ptokens.append("mention_0")
#             else:
#                 ptokens.append("mention_0")
#         else:
#             ptokens.append(token)
#     print(ptokens, file=sys.stderr)
#     return " ".join(ptokens)
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the BERT classifier and load its trained weights (cached).

    The weights are read from ``config.MODEL_PATH`` and mapped onto
    ``config.device``. The result is cached so the checkpoint is read from
    disk on the first request only, not on every prediction.

    Returns:
        The ``BERTBaseUncased`` instance, moved to ``config.device``.
    """
    device = config.device
    model = BERTBaseUncased()
    model.load_state_dict(
        torch.load(config.MODEL_PATH, map_location=torch.device(device))
    )
    model.to(device)
    return model


def sentence_prediction(sentence):
    """Run the fine-tuned BERT model on a single sentence.

    Args:
        sentence: Raw input text from the Gradio text box.

    Returns:
        The first value returned by ``engine.predict_fn`` for a one-item
        data loader built from ``sentence``.
    """
    # Tweet preprocessing is disabled (``preprocess`` is commented out).
    # sentence = preprocess(sentence)

    # A target is passed alongside the review; 0 is presumably a dummy value
    # since no ground truth exists at inference time -- TODO confirm.
    test_dataset = dataset.BERTDataset(review=[sentence], target=[0])
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # DataLoader rejects negative values; 0 loads data in the main process.
        num_workers=0,
    )

    # The second return value is unused. The original unpacked it into ``[]``,
    # which raises unless that value happens to be an empty iterable.
    outputs, _ = engine.predict_fn(test_data_loader, _load_model(), config.device)

    # The original then overwrote ``outputs`` with ``classifier(sentence)``,
    # but no ``classifier`` is defined in this file, so that line could only
    # raise NameError. The model predictions are returned instead.
    print(outputs)

    # NOTE(review): the Interface declares outputs='label'; confirm that the
    # value returned by predict_fn is in a form that component accepts (an
    # earlier variant returned {"label": outputs[0]}).
    return outputs
# Build the Gradio UI: a single free-text input is passed to
# sentence_prediction and the result is shown in a label component.
demo = gr.Interface(fn=sentence_prediction, inputs='text', outputs='label')

# Start serving the demo as soon as the module is executed.
demo.launch()