File size: 2,411 Bytes
e5ffa90
481c6b3
 
 
 
 
db2fdad
481c6b3
db2fdad
5f726f0
 
6649164
98b648b
481c6b3
31209f4
6d9deb3
31209f4
6d9deb3
 
 
 
31209f4
 
481c6b3
 
 
db2fdad
aff6f4b
 
 
8cfc5e5
 
e5ffa90
db2fdad
 
e5ffa90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3b3581
481c6b3
db2fdad
e5ffa90
db2fdad
 
 
 
e5ffa90
db2fdad
 
 
 
 
e5ffa90
98b648b
e5ffa90
155ed73
481c6b3
 
155ed73
e5ffa90
155ed73
db2fdad
ba0c1c3
db2fdad
e5ffa90
db2fdad
e5ffa90
d481ecd
e42c394
 
 
b3b3581
e42c394
b3b3581
e42c394
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
from utils import label_full_decoder
import sys
import dataset
import engine
from model import BERTBaseUncased
# from tokenizer import tokenizer
import config
from transformers import pipeline, AutoTokenizer, AutoModel
import gradio as gr

# DEVICE = config.device

# model = AutoModel.from_pretrained("thak123/bert-emoji-latvian-twitter-classifier")
# 7 EPOCH Version
# BERT_PATH = "FFZG-cleopatra/bert-emoji-latvian-twitter"

# tokenizer = transformers.BertTokenizer.from_pretrained(
#     BERT_PATH,
#     do_lower_case=True
# )
#AutoTokenizer.from_pretrained("FFZG-cleopatra/bert-emoji-latvian-twitter")

# classifier = pipeline("sentiment-analysis",
#                       model= model, 
#                       tokenizer = tokenizer)

# MODEL = BERTBaseUncased()
# MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(DEVICE)))
# MODEL.eval()



# T = tokenizer.TweetTokenizer(
#     preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)

def preprocess(text):
    """Tokenize *text* and replace @-mentions with a placeholder token.

    Tokens come from the module-level tokenizer ``T``. Any token containing
    "@" counts as a mention: each run of consecutive mention tokens is
    replaced by a single "mention_0", and every other token is kept as-is.
    The raw and processed token lists are echoed to stderr for debugging.

    NOTE(review): in this file the assignment to ``T`` is commented out, so
    calling this function raises NameError until a tokenizer is bound to ``T``.

    Returns:
        The processed tokens joined with single spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    kept = []
    prev_was_mention = False
    for token in raw_tokens:
        is_mention = "@" in token
        if not is_mention:
            kept.append(token)
        elif not prev_was_mention:
            # First mention of a run: one placeholder stands for the whole run.
            kept.append("mention_0")
        prev_was_mention = is_mention

    print(kept, file=sys.stderr)
    return " ".join(kept)


# Models already loaded in this process, keyed by (checkpoint path, device).
# This keeps the checkpoint from being re-read from disk on every request.
_MODEL_CACHE = {}


def _load_model(model_path, device):
    """Return a BERTBaseUncased with weights from *model_path*, on *device*, in eval mode.

    The first call for a given (model_path, device) pair builds the model and
    loads the state dict. Later calls return the cached instance.
    """
    key = (model_path, str(device))
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = BERTBaseUncased()
        model.load_state_dict(
            torch.load(model_path, map_location=torch.device(device))
        )
        model.to(device)
        # Inference only: switch off dropout and other training-time behaviour.
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def sentence_prediction(sentence):
    """Classify a single sentence with the fine-tuned BERT model.

    Wraps *sentence* in a one-item ``dataset.BERTDataset`` (with a dummy
    target of 0) and runs it through ``engine.predict_fn`` using the
    checkpoint at ``config.MODEL_PATH`` on ``config.device``.

    Args:
        sentence: The raw input text from the UI.

    Returns:
        The predictions produced by ``engine.predict_fn`` for this sentence.
        NOTE(review): confirm that this format matches what the gradio
        "label" output expects (e.g. a label string or a {label: score} dict).
    """
    # Mention preprocessing stays disabled: `preprocess` needs the tokenizer
    # `T`, whose definition is commented out in this file.
    model_path = config.MODEL_PATH
    device = config.device

    test_dataset = dataset.BERTDataset(review=[sentence], target=[0])
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        # A one-sentence dataset gains nothing from worker processes, and
        # spawning them on every request only adds latency.
        num_workers=0,
    )

    model = _load_model(model_path, device)

    with torch.no_grad():
        # predict_fn returns two values (presumably predictions and targets).
        # The targets are just the dummy [0] above, so they are discarded.
        outputs, _ = engine.predict_fn(test_data_loader, model, device)

    print(outputs)
    return outputs




# The Gradio UI: a text box in, a label out, backed by sentence_prediction.
# `demo` stays at module level so tooling can import and find it.
demo = gr.Interface(
    fn=sentence_prediction,
    inputs="text",
    outputs="label",
)

if __name__ == "__main__":
    # Start the web server only when this file is run as a script. Importing
    # the module (from tests or other tooling) then has no server side effect.
    demo.launch()