|
import torch |
|
from utils import label_full_decoder |
|
import sys |
|
import dataset |
|
import engine |
|
from model import BERTBaseUncased |
|
|
|
import config |
|
from transformers import pipeline, AutoTokenizer, AutoModel |
|
import gradio as gr |
|
|
|
|
|
|
|
import requests

# Hugging Face serves raw file bytes from /resolve/<revision>/<path>. The
# /blob/ URL used previously returns the HTML file-viewer page, so the saved
# "pytorch_model.bin" was not a loadable checkpoint.
URL = "https://huggingface.co/FFZG-cleopatra/bert-emoji-latvian-twitter/resolve/main/pytorch_model.bin"

# Local path the downloaded weights are written to and later loaded from.
model_path = "pytorch_model.bin"

# Stream the checkpoint to disk in chunks instead of buffering it all in
# memory. Fail loudly on HTTP errors instead of saving an error page as the
# "model". Close both the connection and the file deterministically.
with requests.get(URL, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(model_path, "wb") as model_file:
        for chunk in response.iter_content(chunk_size=1 << 20):
            model_file.write(chunk)
|
|
|
|
|
|
|
# Hugging Face Hub repository the tokenizer vocabulary is loaded from.
BERT_PATH = "FFZG-cleopatra/bert-emoji-latvian-twitter"

# Only selected names were imported from `transformers`, never the module
# itself, so `transformers.BertTokenizer` raised NameError. Import it here.
import transformers  # noqa: E402

# Lower-casing tokenizer for the fine-tuned BERT checkpoint.
tokenizer = transformers.BertTokenizer.from_pretrained(
    BERT_PATH,
    do_lower_case=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def preprocess(text, tokenize=None):
    """Normalise @-mentions in *text* before classification.

    Every token containing "@" is replaced by the placeholder "mention_0".
    A run of consecutive such tokens collapses into a single placeholder.
    All other tokens are kept unchanged. The token lists before and after
    processing are printed to stderr for debugging.

    Args:
        text: Raw input sentence.
        tokenize: Callable mapping a string to a list of token strings.
            Defaults to whitespace splitting (``str.split``). The original
            code called ``T.tokenize``, but ``T`` is never defined in this
            file. NOTE(review): a tweet-aware tokenizer may have been
            intended; confirm and pass its ``tokenize`` method here.

    Returns:
        The processed tokens joined by single spaces.
    """
    if tokenize is None:
        tokenize = str.split

    tokens = tokenize(text)
    print(tokens, file=sys.stderr)

    ptokens = []
    for index, token in enumerate(tokens):
        if "@" not in token:
            ptokens.append(token)
        # Only the first token of a run of mentions becomes a placeholder;
        # the rest are dropped.
        elif index == 0 or "@" not in tokens[index - 1]:
            ptokens.append("mention_0")

    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
|
|
|
|
|
def sentence_prediction(sentence):
    """Classify *sentence* for the Gradio interface.

    NOTE(review): this function cannot run as written. Several names it uses
    are not defined anywhere in this file (see the inline notes). The large
    blank region at the start of the original body suggests the
    input-preparation code (e.g. building ``test_data_loader``) is missing;
    TODO confirm.

    Args:
        sentence: The text entered in the Gradio "text" input.

    Returns:
        Whatever ``classifier(sentence)`` returns. Gradio's "label" output
        presumably expects a label string or a {label: confidence} dict;
        verify.
    """
    # NOTE(review): `preprocess` is defined in this file but never applied
    # to `sentence`.

    # A fresh model is built and the checkpoint re-read from disk on every
    # call; consider loading it once at module level.
    model = BERTBaseUncased()

    # NOTE(review): `device` is not defined in this file, so this line raises
    # NameError.
    model.load_state_dict(torch.load(
        model_path, map_location=torch.device(device)))

    model.to(device)
    # NOTE(review): model.eval() is never called before inference.

    # NOTE(review): `test_data_loader` and `MODEL` are undefined here; the
    # freshly loaded local is `model`. The `[]` target requires the second
    # returned value to be an empty iterable, otherwise unpacking raises
    # ValueError.
    outputs, [] = engine.predict_fn(test_data_loader, MODEL, device)

    # NOTE(review): `classifier` is undefined. A transformers `pipeline` is
    # imported but never built. This line would also discard the result
    # computed above.
    outputs = classifier(sentence)

    print(outputs)

    return outputs
|
|
|
|
|
|
|
|
|
# Minimal Gradio UI: a free-text box feeds sentence_prediction, whose result
# is rendered with the "label" output component.
demo = gr.Interface(sentence_prediction, "text", "label")

# Start the Gradio web server for the interface.
demo.launch()
|
|
|
|