AshtonIsNotHere committed on
Commit
b840e20
1 Parent(s): 2e85755

Fix to allow masked token after 512th token

Browse files

Sequences longer than 510 tokens are now truncated to a window around the masked token for xlm-roberta-base, regardless of where the mask is located.

Files changed (1) hide show
  1. app.py +29 -2
app.py CHANGED
@@ -31,8 +31,35 @@ xlmr_tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base', max_length=51
31
  xlmr_p = pipeline("fill-mask", model=model, tokenizer=tokenizer)
32
 
33
  def xlmr_base_fn(text):
34
- text = ' '.join(text.split()[:500])
35
- preds = xlmr_p(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  pred_dict = {}
37
  for pred in preds:
38
  pred_dict[pred['token_str']] = pred['score']
 
31
  xlmr_p = pipeline("fill-mask", model=model, tokenizer=tokenizer)
32
 
33
  def xlmr_base_fn(text):
34
+ # Find our masked token
35
+ tokens = xlmr_tokenizer.tokenize(text)
36
+ mask_token_idx = [i for i, x in enumerate(tokens) if xlmr_tokenizer.mask_token in x][0]
37
+
38
+ max_len = tokenizer.model_max_length
39
+ max_len = max_len-2 if max_len % 512 == 0 and max_len < 4096 else 510
40
+
41
+ # Smart truncation for long sequences
42
+ if not len(tokens) < max_len:
43
+
44
+ # Find left and right bounds for truncated sequences
45
+ lbound = max(0, mask_token_idx-(max_len//2))
46
+ rbound = min(len(tokens), mask_token_idx+(max_len//2))
47
+
48
+ # If we hit an edge, expand sequence in the other direction
49
+ if lbound == 0 and rbound != len(tokens)-1:
50
+ rbound = min(len(tokens), max_len)
51
+ elif rbound == len(tokens) and lbound != 0:
52
+ lbound = max(0, len(tokens)-max_len)
53
+
54
+ # Apply truncation and rejoin tokens to form new text
55
+ truncated_text = ''.join(tokens[lbound:rbound])
56
+
57
+ # Handle lowbar from xlmr tokenizer
58
+ truncated_text = ''.join([x if ord(x) != 9601 else ' ' for x in result])
59
+ else:
60
+ truncated_text = text
61
+
62
+ preds = xlmr_p(truncated_text)
63
  pred_dict = {}
64
  for pred in preds:
65
  pred_dict[pred['token_str']] = pred['score']