danseith commited on
Commit
eabdff9
·
1 Parent(s): 85ace6b

Minor typo fix

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -30,10 +30,12 @@ def add_mask(text, size=1):
30
  if '[MASK]' in split_text:
31
  return text
32
  idx = np.random.randint(len(split_text), size=size)
33
- masked = split_text[idx]
34
  for i in idx:
 
35
  split_text[i] = '[MASK]'
36
- return ' '.join(split_text), masked
 
37
 
38
 
39
  class TempScalePipe(FillMaskPipeline):
@@ -136,9 +138,10 @@ scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
136
  def unmask(text, temp, rounds):
137
  sampling = 'multi'
138
  for round in range(rounds):
139
- text, masked = add_mask(text, size=1)
140
- split_text = text.split()
141
- res = scrambler(text, temp=temp, top_k=10)
 
142
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
143
  out = {item["token_str"]: item["score"] for item in res}
144
  score_to_str = {out[k] : k for k in out.keys()}
@@ -149,7 +152,7 @@ def unmask(text, temp, rounds):
149
  idx = np.random.randint(0, len(score_list))
150
  score = score_list[idx]
151
  new_token = score_to_str[score]
152
- if len(list(new_token)) < 2 or new_token == masked:
153
  continue
154
  split_text[mask_pos] = '*' + new_token + '*'
155
  text = ' '.join(split_text)
 
30
  if '[MASK]' in split_text:
31
  return text
32
  idx = np.random.randint(len(split_text), size=size)
33
+ masked_strings = []
34
  for i in idx:
35
+ masked_strings.append(split_text[i])
36
  split_text[i] = '[MASK]'
37
+ masked_output = ' '.join(split_text)
38
+ return masked_output, masked_strings
39
 
40
 
41
  class TempScalePipe(FillMaskPipeline):
 
138
  def unmask(text, temp, rounds):
139
  sampling = 'multi'
140
  for round in range(rounds):
141
+ tp = add_mask(text, size=1)
142
+ masked_text, masked = tp[0], tp[1]
143
+ split_text = masked_text.split()
144
+ res = scrambler(masked_text, temp=temp, top_k=10)
145
  mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
146
  out = {item["token_str"]: item["score"] for item in res}
147
  score_to_str = {out[k] : k for k in out.keys()}
 
152
  idx = np.random.randint(0, len(score_list))
153
  score = score_list[idx]
154
  new_token = score_to_str[score]
155
+ if len(list(new_token)) < 2 or new_token == masked[0]:
156
  continue
157
  split_text[mask_pos] = '*' + new_token + '*'
158
  text = ' '.join(split_text)