thak123 committed
Commit a5b267d · 1 Parent(s): 7000b9e

Update app.py

Files changed (1)
app.py +58 -1
app.py CHANGED
@@ -1,15 +1,72 @@
 import gradio as gr
 from transformers import pipeline
 
+from model import BERTBaseUncased
+from tokenizer import tokenizer
+import torch
+from utils import label_full_decoder
+import sys
+import config
+import dataset
+import engine
 from model import BERTBaseUncased
 
+MODEL = None
+DEVICE = config.device
 
 def get_sentiment(input_text):
     result = sentiment(input_text)
     return f"result: {result[0]['label']}", f"score: {result[0]['score']}"
 
+def preprocess(text):
+    tokens = T.tokenize(text)
+    print(tokens, file=sys.stderr)
+    ptokens = []
+    for index, token in enumerate(tokens):
+        if "@" in token:
+            if index > 0:
+                # check if previous token was mention
+                if "@" in tokens[index-1]:
+                    pass
+                else:
+                    ptokens.append("mention_0")
+            else:
+                ptokens.append("mention_0")
+        else:
+            ptokens.append(token)
+
+    print(ptokens, file=sys.stderr)
+    return " ".join(ptokens)
+
+
+def sentence_prediction(sentence):
+    sentence = preprocess(sentence)
+    model_path = config.MODEL_PATH
+
+    test_dataset = dataset.BERTDataset(
+        review=[sentence],
+        target=[0]
+    )
+
+    test_data_loader = torch.utils.data.DataLoader(
+        test_dataset,
+        batch_size=config.VALID_BATCH_SIZE,
+        num_workers=3
+    )
+
+    device = config.device
+
+    model = BERTBaseUncased()
+    model.load_state_dict(torch.load(
+        model_path, map_location=torch.device(device)))
+    model.to(device)
+
+    outputs, [] = engine.predict_fn(test_data_loader, model, device)
+    print(outputs)
+    return outputs[0]
+
 interface = gr.Interface(
-    fn=get_sentiment,
+    fn=sentence_prediction,
     inputs='text',
     outputs=['text', 'text'],
     title='Sentiment Analysis',
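
The added preprocess helper masks "@" mentions before the text reaches the model, keeping a single mention_0 placeholder per consecutive run of mentions. As committed it calls T.tokenize(...) while this revision only imports tokenizer, so T looks undefined here. Below is a minimal, self-contained sketch of the same masking logic; the whitespace str.split tokenizer is an assumption used only to keep the sketch runnable, not the project's actual tokenizer:

import sys

def mask_mentions(text, tokenize=str.split):
    # `tokenize` stands in for the repo's tokenizer.tokenize (assumption).
    tokens = tokenize(text)
    ptokens = []
    for index, token in enumerate(tokens):
        if "@" in token:
            # Skip a mention whose previous token was also a mention,
            # so each run of mentions collapses to one placeholder.
            if index > 0 and "@" in tokens[index - 1]:
                continue
            ptokens.append("mention_0")
        else:
            ptokens.append(token)
    print(ptokens, file=sys.stderr)
    return " ".join(ptokens)

# Example: consecutive mentions collapse into one placeholder.
print(mask_mentions("@alice @bob thanks for the review"))
# -> "mention_0 thanks for the review"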
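
One thing to watch: the interface keeps outputs=['text', 'text'] from the old get_sentiment (which returned a label string and a score string), while the new sentence_prediction returns a single value, so the second text output would likely be left without a matching return value. A minimal sketch of the shape Gradio expects when two text outputs are declared, using a hypothetical label_and_score wrapper in place of the real prediction path:

import gradio as gr

def label_and_score(text):
    # Hypothetical stand-in for sentence_prediction + label_full_decoder;
    # fixed values keep the sketch self-contained and runnable.
    label, score = "POSITIVE", 0.99
    return f"result: {label}", f"score: {score}"

# Two output components, so the function returns two values.
demo = gr.Interface(
    fn=label_and_score,
    inputs='text',
    outputs=['text', 'text'],
    title='Sentiment Analysis',
)

if __name__ == "__main__":
    demo.launch()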