import torch
import gradio as gr

import config
import dataset
import engine
from model import BERTBaseUncased
from utils import label_full_decoder


def sentence_prediction(sentence):
    model_path = config.MODEL_PATH

    # Wrap the single input sentence in the project's dataset class;
    # the target is a dummy value because only the prediction is needed.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0]
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=0  # -1 is not a valid value; 0 loads batches in the main process
    )

    device = config.device

    model = BERTBaseUncased()
    model.load_state_dict(torch.load(
        model_path, map_location=torch.device(device)))
    model.to(device)
    model.eval()  # inference mode: disables dropout

    # predict_fn appears to return a second value alongside the outputs;
    # it is not needed here.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)

    print(outputs)
    # Decode the raw prediction into a human-readable label
    # (assuming utils.label_full_decoder maps the model output to its label).
    return label_full_decoder(outputs)
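
# Quick local sanity check, independent of the Gradio UI
# (hypothetical example input):
# print(sentence_prediction("The delivery was fast and the product works great."))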


# Minimal Gradio UI: free-text input in, predicted label out.
demo = gr.Interface(
    fn=sentence_prediction,
    inputs='text',
    outputs='label',
)

demo.launch()
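
# To expose the demo through a temporary public URL, Gradio's
# demo.launch(share=True) option can be used instead.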