thak123 commited on
Commit
98b648b
·
1 Parent(s): af3ae7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  import sys
3
  import dataset
4
  import engine
@@ -8,6 +9,9 @@ import config
8
 
9
  import gradio as gr
10
 
 
 
 
11
 
12
  T = tokenizer.TweetTokenizer(
13
  preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
@@ -48,24 +52,30 @@ def sentence_prediction(sentence):
48
  num_workers=3
49
  )
50
 
51
- device = config.device
52
 
53
- model = BERTBaseUncased()
54
- model.load_state_dict(torch.load(
55
- model_path, map_location=torch.device(device)))
56
- model.to(device)
57
 
58
  outputs, [] = engine.predict_fn(test_data_loader, model, device)
59
  print(outputs)
60
  return {"label":outputs[0]}
61
 
62
- demo = gr.Interface(
63
- fn=sentence_prediction,
64
- inputs=gr.Textbox(placeholder="Enter a sentence here..."),
65
- # outputs="label",
66
- # interpretation="default",
67
- examples=[["!"]])
68
-
69
- demo.launch(debug = True,
70
- enable_queue=True,
71
- show_error = True)
 
 
 
 
 
 
 
1
  import torch
2
+ from utils import label_full_decoder
3
  import sys
4
  import dataset
5
  import engine
 
9
 
10
  import gradio as gr
11
 
12
+ MODEL = None
13
+ DEVICE = config.device
14
+
15
 
16
  T = tokenizer.TweetTokenizer(
17
  preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
 
52
  num_workers=3
53
  )
54
 
55
+ # device = config.device
56
 
57
+ # model = BERTBaseUncased()
58
+ # model.load_state_dict(torch.load(
59
+ # model_path, map_location=torch.device(device)))
60
+ # model.to(device)
61
 
62
  outputs, [] = engine.predict_fn(test_data_loader, model, device)
63
  print(outputs)
64
  return {"label":outputs[0]}
65
 
66
+ if __name__ == "__main__":
67
+ MODEL = BERTBaseUncased()
68
+ MODEL.load_state_dict(torch.load(
69
+ config.MODEL_PATH, map_location=torch.device(DEVICE)))
70
+ MODEL.eval()
71
+
72
+ demo = gr.Interface(
73
+ fn=sentence_prediction,
74
+ inputs=gr.Textbox(placeholder="Enter a sentence here..."),
75
+ outputs="label",
76
+ # interpretation="default",
77
+ examples=[["!"]])
78
+
79
+ demo.launch(debug = True,
80
+ enable_queue=True,
81
+ show_error = True)