import functools
import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
import config
from transformers import pipeline, AutoTokenizer, AutoModel
import gradio as gr

# T = tokenizer.TweetTokenizer(
#     preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)


# def preprocess(text):
#     tokens = T.tokenize(text)
#     print(tokens, file=sys.stderr)
#     ptokens = []
#     for index, token in enumerate(tokens):
#         if "@" in token:
#             if index > 0:
#                 # check if previous token was mention
#                 if "@" in tokens[index-1]:
#                     pass
#                 else:
#                     ptokens.append("mention_0")
#             else:
#                 ptokens.append("mention_0")
#         else:
#             ptokens.append(token)
#     print(ptokens, file=sys.stderr)
#     return " ".join(ptokens)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build BERTBaseUncased, load weights from config.MODEL_PATH, and move it to config.device.

    Cached so the checkpoint is read from disk once per process instead of on
    every Gradio request.
    """
    device = config.device
    model = BERTBaseUncased()
    model.load_state_dict(
        torch.load(config.MODEL_PATH, map_location=torch.device(device))
    )
    model.to(device)
    # Inference only: put layers such as dropout into evaluation mode.
    model.eval()
    return model


def sentence_prediction(sentence):
    """Run the fine-tuned BERT classifier on a single sentence.

    Args:
        sentence: Raw input text from the Gradio textbox.

    Returns:
        The first element of the pair returned by ``engine.predict_fn`` for a
        one-item dataset.
        NOTE(review): Gradio's "label" output expects a str/number or a
        {label: confidence} dict. Confirm what predict_fn returns and map it
        (possibly via ``label_full_decoder``) if the UI does not render it.
    """
    # sentence = preprocess(sentence)
    # target=[0] is presumably a placeholder required by BERTDataset; only
    # the predictions are used below.
    test_dataset = dataset.BERTDataset(review=[sentence], target=[0])
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # num_workers must be >= 0 (-1 makes DataLoader raise ValueError).
        # 0 loads in the main process, which is cheapest for a single sample.
        num_workers=0,
    )
    device = config.device
    model = _load_model()
    # predict_fn yields a pair; only the first element (the predictions) is
    # used. `_` accepts whatever the second element is, unlike the previous
    # `[]` target, which failed unless it was empty.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return outputs  # {"label": outputs[0]}


demo = gr.Interface(
    fn=sentence_prediction,
    inputs='text',
    outputs='label',
)

demo.launch()