import sys

import torch
import gradio as gr

import config
import dataset
import engine
from model import BERTBaseUncased
from tokenizer import tokenizer
from utils import label_full_decoder

DEVICE = config.device

# Tweet-aware tokenizer: keep @handles and #hashtags, lowercase, drop URLs.
T = tokenizer.TweetTokenizer(
    preserve_handles=True,
    preserve_hashes=True,
    preserve_case=False,
    preserve_url=False,
)


def preprocess(text):
    """Tokenize *text* and collapse each run of @-mentions to a single
    ``mention_0`` placeholder token.

    Consecutive mentions after the first are dropped (the model was
    presumably trained on anonymized single-mention text — TODO confirm).

    :param text: raw input sentence.
    :return: space-joined, preprocessed token string.
    """
    tokens = T.tokenize(text)
    print(tokens, file=sys.stderr)  # debug: raw token stream

    ptokens = []
    for index, token in enumerate(tokens):
        if "@" not in token:
            ptokens.append(token)
            continue
        # Token is a mention: emit the placeholder only if the previous
        # token was not also a mention (collapse runs of mentions).
        if index > 0 and "@" in tokens[index - 1]:
            continue
        ptokens.append("mention_0")

    print(ptokens, file=sys.stderr)  # debug: anonymized token stream
    return " ".join(ptokens)


def sentence_prediction(sentence):
    """Run the fine-tuned BERT classifier on a single *sentence*.

    :param sentence: raw input text.
    :return: dict ``{"label": <predicted label>}`` suitable for a
        Gradio ``"label"`` output component.
    """
    sentence = preprocess(sentence)
    model_path = config.MODEL_PATH

    # Wrap the single sentence in the project's dataset/dataloader so the
    # same tokenization/padding path as training is used.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],  # dummy target; unused at inference time
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3,
    )

    device = config.device

    # NOTE(review): the model is rebuilt and reloaded on every request;
    # consider loading once at module import time for lower latency.
    model = BERTBaseUncased()
    # Fix: the checkpoint load was commented out and the undefined name
    # `MODEL` was passed to predict_fn — load the weights and use the
    # local `model` instead.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device))
    )
    model.to(device)

    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return {"label": outputs[0]}


if __name__ == "__main__":
    demo = gr.Interface(
        fn=sentence_prediction,
        inputs=gr.Textbox(placeholder="Enter a sentence here..."),
        outputs="label",
        examples=[["!"]],
    )
    demo.launch(debug=True, enable_queue=True, show_error=True)