import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
import config
import gradio as gr
from ekphrasis.classes.preprocessor import TextPreProcessor
from ekphrasis.classes.tokenizer import SocialTokenizer
from ekphrasis.dicts.emoticons import emoticons
device = config.device
model = BERTBaseUncased()
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)),
                      strict=False)  # strict=False tolerates key mismatches with the checkpoint
model.to(device)
model.eval()  # inference mode: disable dropout
text_processor = TextPreProcessor(
    # terms that will be normalized (e.g. a URL becomes <url>)
    normalize=['url', 'email', 'percent', 'money', 'phone', 'user'],
    # terms that will be annotated
    annotate={},
    fix_html=True,  # fix HTML tokens
    # corpus from which the word statistics are used for word segmentation
    segmenter="twitter",
    # corpus from which the word statistics are used for spell correction
    corrector="twitter",
    unpack_hashtags=False,      # perform word segmentation on hashtags
    unpack_contractions=False,  # unpack contractions (can't -> can not)
    spell_correct_elong=False,  # spell correction for elongated words
    # the tokenizer should take a string and return a list of tokens;
    # SocialTokenizer is tuned for social-media text
    tokenizer=SocialTokenizer(lowercase=True).tokenize,
    # list of dictionaries for replacing tokens extracted from the text
    # with other expressions; more than one dictionary can be passed
    dicts=[]
)
def preprocess(text):
    # run ekphrasis normalization + tokenization on the single input text
    # (pre_process_doc takes one document; pre_process_docs expects a list)
    tokens = text_processor.pre_process_doc(text)
    print(tokens, file=sys.stderr)
    ptokens = []
    for index, token in enumerate(tokens):
        if "@" in token:
            # collapse a run of consecutive @mentions into one placeholder
            if index > 0 and "@" in tokens[index - 1]:
                continue
            ptokens.append("mention_0")
        else:
            ptokens.append(token)
    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
def sentence_prediction(sentence):
    sentence = preprocess(sentence)

    # single-example dataset; the target is a dummy value required by BERTDataset
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0]
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=2
    )

    # predict_fn also returns a second value, which is unused here
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs, file=sys.stderr)
    # decode the numeric prediction into its label string
    # (assumes label_full_decoder maps a single class id to a label name)
    return label_full_decoder(outputs[0])
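# Hypothetical quick check (actual labels depend on the checkpoint and on
# utils.label_full_decoder):
#   print(sentence_prediction("This movie was great!"))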
demo = gr.Interface(
    fn=sentence_prediction,
    inputs='text',
    outputs='label',
)

demo.launch()
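# Running `python app.py` serves the Gradio UI (by default on
# http://127.0.0.1:7860) with a free-text input and a label output.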