File size: 1,812 Bytes
e710478
 
 
a5b267d
 
 
 
 
 
 
 
4b096e3
e710478
a5b267d
55b119f
c0f15ba
e710478
55b119f
 
 
b2aa5d6
55b119f
e710478
a5b267d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55b119f
a5b267d
 
 
 
 
 
 
 
 
 
 
 
 
 
55b119f
a5b267d
e710478
a5b267d
e710478
11504f2
e710478
55b119f
e710478
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from transformers import pipeline

from model import BERTBaseUncased
from tokenizer import tokenizer
import torch
from utils import label_full_decoder
import sys
import config
import dataset
import engine
from model import BERTBaseUncased

# Unused placeholder kept for backward compatibility with earlier revisions;
# the live model object is the module-level `model` below.
MODEL = None

# Tweet tokenizer configured to keep @handles and #hashtags (handles are
# normalized later in preprocess), lowercase everything, and drop URLs.
T = tokenizer.TweetTokenizer(preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)

device = config.device

# Load the fine-tuned weights once at import time. map_location allows a
# checkpoint saved on GPU to be loaded onto whatever device config selects.
model = BERTBaseUncased()
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)))
model.to(device)
# Fix: this script only runs inference, so the model must be switched to
# eval mode — otherwise dropout stays active and predictions are noisy.
model.eval()

def preprocess(text):
    """Tokenize *text* and normalize @-mentions.

    Each run of one or more consecutive tokens containing "@" is collapsed
    into a single placeholder token "mention_0"; every other token passes
    through unchanged. Returns the processed tokens joined by spaces.
    """
    raw_tokens = T.tokenize(text)
    print(raw_tokens, file=sys.stderr)

    cleaned = []
    prev_was_mention = False
    for token in raw_tokens:
        if "@" in token:
            # Emit the placeholder only for the first mention in a run.
            if not prev_was_mention:
                cleaned.append("mention_0")
            prev_was_mention = True
        else:
            cleaned.append(token)
            prev_was_mention = False

    print(cleaned, file=sys.stderr)
    return " ".join(cleaned)


def sentence_prediction(sentence):
    """Predict the sentiment label for a single raw input string.

    The text is preprocessed (mention normalization), wrapped in a
    one-element dataset/dataloader, run through the loaded model, and the
    first prediction is decoded into its human-readable label.
    """
    sentence = preprocess(sentence)

    # Dummy target of 0: BERTDataset requires a target, but it is unused
    # at prediction time.
    test_dataset = dataset.BERTDataset(
        review=[sentence],
        target=[0]
    )

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=3
    )

    # Fix: the original unpacked into an empty-list target (`outputs, [] = ...`),
    # which raises ValueError the moment the second element is non-empty.
    # Use the conventional throwaway name instead.
    outputs, _ = engine.predict_fn(test_data_loader, model, device)
    print(outputs)
    return label_full_decoder(outputs[0])
    
# Wire the prediction function into a minimal Gradio UI:
# one text box in, the decoded sentiment label out.
interface = gr.Interface(
    fn=sentence_prediction,
    inputs="text",
    outputs=["text"],
    title="Sentiment Analysis",
    description="Get the positive/neutral/negative sentiment for the given input.",
)

# Serve the app; inline=False opens it as a standalone page rather than
# embedding it in a notebook cell.
interface.launch(inline=False)