# NOTE(review): removed non-Python residue (a file-size header, git-blame
# commit hashes, and a line-number gutter) that had been pasted into the
# file head and would break parsing.
# Standard library
import sys

# Third-party
import gradio as gr
import requests
import torch
import transformers
from transformers import pipeline, AutoTokenizer, AutoModel

# Local
import config
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
# DEVICE = config.device
from utils import label_full_decoder
# Download the fine-tuned model weights once at startup.
# Bug fix: the original URL used /blob/, which serves the HTML file-viewer
# page; /resolve/ serves the raw binary file.
URL = "https://huggingface.co/FFZG-cleopatra/bert-emoji-latvian-twitter/resolve/main/pytorch_model.bin"
response = requests.get(URL)
response.raise_for_status()  # fail fast instead of writing an error page to disk
with open("pytorch_model.bin", "wb") as f:  # context manager: always close the handle
    f.write(response.content)
model_path = "pytorch_model.bin"
# model = AutoModel.from_pretrained("thak123/bert-emoji-latvian-twitter-classifier")
# 7 EPOCH Version
# HF hub id of the fine-tuned Latvian Twitter emoji BERT (7-epoch version).
BERT_PATH = "FFZG-cleopatra/bert-emoji-latvian-twitter"
# NOTE(review): this line previously raised NameError — only
# `from transformers import ...` names were in scope, never the bare
# `transformers` module. `import transformers` is now added to the
# header imports.
tokenizer = transformers.BertTokenizer.from_pretrained(
    BERT_PATH,
    do_lower_case=True
)
#AutoTokenizer.from_pretrained("FFZG-cleopatra/bert-emoji-latvian-twitter")
# classifier = pipeline("sentiment-analysis",
# model= model,
# tokenizer = tokenizer)
# MODEL = BERTBaseUncased()
# MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(DEVICE)))
# MODEL.eval()
# T = tokenizer.TweetTokenizer(
# preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
def preprocess(text):
    """Tokenize a tweet and replace @-mentions with the "mention_0" token.

    A run of consecutive mention tokens collapses into a single
    "mention_0" placeholder. Returns the processed tokens joined by
    single spaces.

    :param text: raw tweet text
    :return: preprocessed, space-joined token string
    """
    # Bug fix: the original called T.tokenize(), but T (an nltk-style
    # TweetTokenizer) is commented out above and would raise NameError.
    # Fall back to a plain whitespace split.
    # TODO(review): restore TweetTokenizer if handle/URL/case handling
    # must match the training-time preprocessing exactly.
    tokens = text.split()
    print(tokens, file=sys.stderr)
    ptokens = []
    for index, token in enumerate(tokens):
        if "@" not in token:
            ptokens.append(token)
        elif index == 0 or "@" not in tokens[index - 1]:
            # first mention in a run -> emit the placeholder once
            ptokens.append("mention_0")
        # else: subsequent mention in the same run -> drop it
    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)
def sentence_prediction(sentence):
    """Classify one sentence with the fine-tuned BERT model.

    Bug fixes vs. the original body: it referenced four undefined names
    (`device`, `test_data_loader`, `MODEL`, `classifier`) — the code that
    defined them was commented out — and then discarded the model output
    by re-assigning `outputs` from the undefined `classifier`. This
    version reconstructs the commented-out pipeline.

    :param sentence: raw input text from the Gradio textbox
    :return: the prediction produced by ``engine.predict_fn``
    """
    # Wrap the single sentence in the project dataset/dataloader types.
    # target=[0] is a dummy label; only the prediction is used.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0],
    )
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3,
    )
    device = config.device
    model = BERTBaseUncased()
    model.load_state_dict(torch.load(
        model_path, map_location=torch.device(device)))
    model.to(device)
    model.eval()  # inference mode: disable dropout
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return outputs
# Gradio UI: a single text box in, a predicted sentiment label out.
demo = gr.Interface(fn=sentence_prediction, inputs='text', outputs='label')
demo.launch()
# NOTE(review): removed a stray trailing "|" left over from the pasted gutter.