|
import torch |
|
import sys |
|
import dataset |
|
import engine |
|
from model import BERTBaseUncased |
|
from tokenizer import tokenizer |
|
import config |
|
|
|
|
|
# Module-level tweet tokenizer shared by preprocess().
# The flags below are passed straight through to the project's
# tokenizer.TweetTokenizer; their exact effect is defined in that module
# (by name they control handle/hashtag/case/URL handling — confirm there).
T = tokenizer.TweetTokenizer(
    preserve_case=False,
    preserve_handles=True,
    preserve_hashes=True,
    preserve_url=False,
)
|
|
|
def preprocess(text):
    """Tokenize *text* with the module tokenizer ``T`` and normalise mentions.

    Every token containing "@" is replaced by the placeholder "mention_0";
    a run of consecutive "@"-containing tokens yields a single placeholder
    (only the first token of the run emits it). All other tokens are kept
    unchanged. The raw token list and the processed token list are printed
    to stderr for debugging.

    Returns:
        The processed tokens joined with single spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    cleaned = []
    for position, tok in enumerate(raw_tokens):
        if "@" not in tok:
            cleaned.append(tok)
            continue
        # Collapse adjacent mentions: only emit a placeholder when the
        # previous token was not itself a mention.
        follows_mention = position > 0 and "@" in raw_tokens[position - 1]
        if not follows_mention:
            cleaned.append("mention_0")

    print(cleaned, file=sys.stderr)
    return " ".join(cleaned)
|
|
|
|
|
# Loaded models keyed by (weights path, device string). Reading the BERT
# weights from disk on every request was the dominant per-call cost; this
# keeps one instance per process. Note: a weights file changed on disk
# after the first load is not picked up until the process restarts.
_MODEL_CACHE = {}


def _load_model(model_path, device):
    """Return a BERTBaseUncased loaded from *model_path* on *device* (cached)."""
    key = (model_path, str(device))
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = BERTBaseUncased()
        model.load_state_dict(
            torch.load(model_path, map_location=torch.device(device))
        )
        model.to(device)
        # Inference only: put the module in eval mode so layers whose
        # behaviour differs between train/eval (e.g. dropout, if present)
        # act deterministically.
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def sentence_prediction(sentence):
    """Run the classifier on one raw input sentence.

    The sentence is normalised with ``preprocess``, wrapped in a one-item
    ``dataset.BERTDataset`` (with a dummy target of 0), and passed through
    ``engine.predict_fn`` using the model at ``config.MODEL_PATH`` on
    ``config.device``.

    Returns:
        ``{"label": outputs[0]}`` where ``outputs`` is the first value
        returned by ``engine.predict_fn`` (presumably the model score for
        this sentence — confirm against engine.predict_fn).
    """
    sentence = preprocess(sentence)

    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],
    )

    # A single-sample loader gains nothing from worker processes; load in
    # the main process to avoid spawning workers on every request.
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=0,
    )

    device = config.device
    model = _load_model(config.MODEL_PATH, device)

    # The second return value is not used. The previous `outputs, [] = ...`
    # pattern raised ValueError whenever it was non-empty.
    outputs, _targets = engine.predict_fn(test_data_loader, model, device)

    print(outputs)

    return {"label": outputs[0]}
|
|
|
# Gradio is only used by this script section; it was referenced as `gr`
# without ever being imported, which raised NameError at startup.
import gradio as gr

# Web UI: a single text box whose value is classified by sentence_prediction
# and shown with Gradio's "label" output component.
demo = gr.Interface(
    fn=sentence_prediction,
    inputs=gr.Textbox(placeholder="Enter a sentence here..."),
    outputs="label",
    examples=[["!"]],
)

demo.launch()
|
|