thak123 committed on
Commit
623670e
1 Parent(s): 1c9935e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -25
app.py CHANGED
@@ -4,7 +4,7 @@ import sys
4
  import dataset
5
  import engine
6
  from model import BERTBaseUncased
7
- # from tokenizer import tokenizer
8
  import config
9
  from transformers import pipeline, AutoTokenizer, AutoModel
10
  import gradio as gr
@@ -14,32 +14,32 @@ model = BERTBaseUncased()
14
  model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)),strict=False)
15
  model.to(device)
16
 
17
- # T = tokenizer.TweetTokenizer(
18
- # preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
19
-
20
- # def preprocess(text):
21
- # tokens = T.tokenize(text)
22
- # print(tokens, file=sys.stderr)
23
- # ptokens = []
24
- # for index, token in enumerate(tokens):
25
- # if "@" in token:
26
- # if index > 0:
27
- # # check if previous token was mention
28
- # if "@" in tokens[index-1]:
29
- # pass
30
- # else:
31
- # ptokens.append("mention_0")
32
- # else:
33
- # ptokens.append("mention_0")
34
- # else:
35
- # ptokens.append(token)
36
-
37
- # print(ptokens, file=sys.stderr)
38
- # return " ".join(ptokens)
39
 
40
 
41
  def sentence_prediction(sentence):
42
- # sentence = preprocess(sentence)
43
 
44
  model_path = config.MODEL_PATH
45
 
@@ -51,7 +51,7 @@ def sentence_prediction(sentence):
51
  test_data_loader = torch.utils.data.DataLoader(
52
  test_dataset,
53
  batch_size=config.VALID_BATCH_SIZE,
54
- num_workers=-1
55
  )
56
 
57
  outputs, [] = engine.predict_fn(test_data_loader, model, device)
 
4
  import dataset
5
  import engine
6
  from model import BERTBaseUncased
7
+ from tokenizer import tokenizer
8
  import config
9
  from transformers import pipeline, AutoTokenizer, AutoModel
10
  import gradio as gr
 
14
  model.load_state_dict(torch.load(config.MODEL_PATH, map_location=torch.device(device)),strict=False)
15
  model.to(device)
16
 
17
+ T = tokenizer.TweetTokenizer(
18
+ preserve_handles=True, preserve_hashes=True, preserve_case=False, preserve_url=False)
19
+
20
+ def preprocess(text):
21
+ tokens = T.tokenize(text)
22
+ print(tokens, file=sys.stderr)
23
+ ptokens = []
24
+ for index, token in enumerate(tokens):
25
+ if "@" in token:
26
+ if index > 0:
27
+ # check if previous token was mention
28
+ if "@" in tokens[index-1]:
29
+ pass
30
+ else:
31
+ ptokens.append("mention_0")
32
+ else:
33
+ ptokens.append("mention_0")
34
+ else:
35
+ ptokens.append(token)
36
+
37
+ print(ptokens, file=sys.stderr)
38
+ return " ".join(ptokens)
39
 
40
 
41
  def sentence_prediction(sentence):
42
+ sentence = preprocess(sentence)
43
 
44
  model_path = config.MODEL_PATH
45
 
 
51
  test_data_loader = torch.utils.data.DataLoader(
52
  test_dataset,
53
  batch_size=config.VALID_BATCH_SIZE,
54
+ num_workers=2
55
  )
56
 
57
  outputs, [] = engine.predict_fn(test_data_loader, model, device)