chenjg committed on
Commit
129cc39
·
1 Parent(s): e406754

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -25,28 +25,32 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=0.0, filter_value=-float('Inf'
25
  logits[indices_to_remove] = filter_value
26
  return logits
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def generate(input_text):
29
- result = []
30
- for i in range(0,5):
31
- generated = []
32
- for _ in range(100):
33
- input_ids = [tokenizer.cls_token_id]
34
- input_ids.extend( tokenizer.encode(input_text + "-", add_special_tokens=False) )
35
- input_ids = torch.tensor( [input_ids] )
36
-
37
- output = model(input_ids)
38
-
39
- next_token_logits = output.logits[0, -1, :]
40
- next_token_logits[ tokenizer.convert_tokens_to_ids('[UNK]') ] = -float('Inf')
41
- filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=8, top_p=1)
42
- next_token = torch.multinomial( F.softmax(filtered_logits, dim=-1), num_samples=1 )
43
- if next_token == tokenizer.sep_token_id:
44
- break
45
- generated.append( next_token.item() )
46
- input_ids = torch.cat( (input_ids, next_token.unsqueeze(0)), dim=1 )
47
- result.append("".join(tokenizer.convert_ids_to_tokens(generated)));
48
-
49
- return "|".join( result )
50
 
51
  if __name__ == "__main__":
52
 
@@ -54,5 +58,4 @@ if __name__ == "__main__":
54
  fn=generate,
55
  inputs="text",
56
  outputs="text"
57
- ).launch()
58
-
 
25
  logits[indices_to_remove] = filter_value
26
  return logits
27
 
28
def generate0(input_text, max_new_tokens=100):
    """Sample one continuation of *input_text* from the language model.

    Builds a model input of [CLS] + encode(input_text + "-"), then repeatedly
    samples the next token (top-k=8 filtered, multinomial over the softmax)
    until the [SEP] token is drawn or *max_new_tokens* tokens were produced.

    Args:
        input_text: prompt string; a "-" is appended before encoding.
        max_new_tokens: sampling budget per call (default 100, matching the
            previous hard-coded loop bound).

    Returns:
        The generated text as a single string (special tokens excluded).
    """
    input_ids = [tokenizer.cls_token_id]
    input_ids.extend(tokenizer.encode(input_text + "-", add_special_tokens=False))
    input_ids = torch.tensor([input_ids])

    generated = []
    # Pure inference: no_grad avoids building an autograd graph on every
    # forward pass, which otherwise grows memory with each sampled token.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            output = model(input_ids)

            next_token_logits = output.logits[0, -1, :]
            # Never sample the unknown token.
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=8, top_p=1)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            if next_token == tokenizer.sep_token_id:
                break
            generated.append(next_token.item())
            # Append the sampled token and feed the grown sequence back in.
            input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)

    return "".join(tokenizer.convert_ids_to_tokens(generated))
48
def generate(input_text):
    """Draw 100 independent samples for *input_text* and concatenate them.

    NOTE(review): the samples are joined with no separator — the previous
    revision used "|" between samples; confirm the unseparated output is
    intentional.
    """
    samples = [generate0(input_text) for _ in range(100)]
    return "".join(samples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  if __name__ == "__main__":
56
 
 
58
  fn=generate,
59
  inputs="text",
60
  outputs="text"
61
+ ).launch()