File size: 2,463 Bytes
e5ffa90
db2fdad
272a9cc
db2fdad
 
 
 
 
 
5f726f0
 
6649164
98b648b
19254fe
 
7f32c66
19254fe
 
 
db2fdad
aff6f4b
 
 
8cfc5e5
 
e5ffa90
db2fdad
 
e5ffa90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db2fdad
e5ffa90
db2fdad
 
 
 
e5ffa90
db2fdad
 
 
 
 
e5ffa90
98b648b
e5ffa90
db2fdad
 
 
 
e5ffa90
db2fdad
 
ba0c1c3
db2fdad
e5ffa90
db2fdad
e5ffa90
d481ecd
6649164
272a9cc
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
# from utils import label_full_decoder
# import sys
# import dataset
# import engine
# from model import BERTBaseUncased
# from tokenizer import tokenizer
# import config
from transformers import pipeline, AutoTokenizer, AutoModel
import gradio as gr

# DEVICE = config.device

# model = AutoModel.from_pretrained("thak123/bert-emoji-latvian-twitter-classifier")
# tokenizer = AutoTokenizer.from_pretrained("FFZG-cleopatra/bert-emoji-latvian-twitter")

# classifier = pipeline("sentiment-analysis",
#                       model= model, 
#                       tokenizer = tokenizer)

# MODEL = BERTBaseUncased()
# MODEL.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(DEVICE)))
# MODEL.eval()



# T = tokenizer.TweetTokenizer(
#     preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)

def preprocess(text, tokenize=None):
    """Tokenize a tweet and collapse runs of @-mentions into one placeholder.

    Every token containing "@" is treated as a mention; a run of consecutive
    mention tokens is replaced by a single ``"mention_0"`` token. All other
    tokens are kept unchanged. The token lists before and after processing
    are printed to stderr for debugging.

    Args:
        text: Raw input string.
        tokenize: Optional callable ``str -> list[str]``. Defaults to the
            module-level ``T.tokenize`` when ``T`` is defined, otherwise to
            plain whitespace splitting (``str.split``).

    Returns:
        The processed tokens joined with single spaces.
    """
    # `sys` is not imported at module level (that import is commented out).
    import sys

    if tokenize is None:
        # NOTE(review): `T` is meant to be the module-level TweetTokenizer, but
        # its construction is commented out; fall back to whitespace splitting
        # instead of failing with NameError.
        tweet_tokenizer = globals().get("T")
        tokenize = tweet_tokenizer.tokenize if tweet_tokenizer is not None else str.split

    tokens = tokenize(text)
    print(tokens, file=sys.stderr)

    ptokens = []
    for index, token in enumerate(tokens):
        if "@" not in token:
            ptokens.append(token)
        elif index == 0 or "@" not in tokens[index - 1]:
            # First mention of a run: emit one placeholder; later mentions in
            # the same run are dropped.
            ptokens.append("mention_0")

    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)


def sentence_prediction(sentence, *, classifier=None):
    """Classify one sentence and return per-label scores for ``gr.Label``.

    The input is first normalised with ``preprocess`` (runs of @-mentions are
    collapsed to ``mention_0``) and then passed to the classifier.

    Args:
        sentence: Raw text entered in the Gradio textbox.
        classifier: Optional callable following the Hugging Face
            text-classification pipeline convention: it takes a string and
            returns a list of ``{"label": ..., "score": ...}`` dicts (a single
            dict is also accepted). Defaults to the module-level ``classifier``.

    Returns:
        A dict mapping each returned label to its score. This is the format
        ``gr.Label`` accepts; the raw pipeline output (a list of dicts) is not.

    Raises:
        RuntimeError: if no classifier was passed and no module-level
            ``classifier`` exists.
    """
    if classifier is None:
        # The parameter shadows the module-level name, so look it up explicitly.
        classifier = globals().get("classifier")
        if classifier is None:
            # NOTE(review): the module-level pipeline construction is commented
            # out near the top of this file; restore it before launching.
            raise RuntimeError(
                "No classifier available: pass classifier=... or define the "
                "module-level `classifier` pipeline."
            )

    outputs = classifier(preprocess(sentence))
    print(outputs)

    if isinstance(outputs, dict):
        outputs = [outputs]
    return {item["label"]: float(item["score"]) for item in outputs}


     
# demo = gr.Interface(
#     fn=sentence_prediction, 
#     inputs=gr.Textbox(placeholder="Enter a  sentence here..."), 
#     outputs="label", 
#     interpretation="default",
#     examples=[["!"]])

# demo.launch()

# Bind the interface to `demo` and launch only when run as a script, so that
# importing this module does not start a server as a side effect.
demo = gr.Interface(fn=sentence_prediction, inputs="text", outputs="label")

if __name__ == "__main__":
    demo.launch()