satyanayak committed on
Commit
c121a67
·
1 Parent(s): e243b3e

<|endoftext|> token handled

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -31,8 +31,9 @@ model = load_model_from_hf()
31
  model.train(False)
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
- # Initialize encoder with allowed special tokens
35
- enc = tiktoken.get_encoding('gpt2', allowed_special={'<|endoftext|>'})
 
36
  tokens = enc.encode(prompt)
37
  tokens = torch.tensor(tokens, dtype=torch.long)
38
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
@@ -55,7 +56,8 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
55
  tokens = torch.cat((tokens, next_token), dim=1)
56
 
57
  # Check for end of text token
58
- if next_token.item() == enc.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:
 
59
  break
60
 
61
  generated_texts = []
 
31
  model.train(False)
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
+ enc = tiktoken.get_encoding('gpt2')
35
+ # Modify encoding behavior to allow special tokens
36
+ enc._special_tokens.add('<|endoftext|>')
37
  tokens = enc.encode(prompt)
38
  tokens = torch.tensor(tokens, dtype=torch.long)
39
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
 
56
  tokens = torch.cat((tokens, next_token), dim=1)
57
 
58
  # Check for end of text token
59
+ endoftext_token = enc.encode('<|endoftext|>')[0]
60
+ if next_token.item() == endoftext_token:
61
  break
62
 
63
  generated_texts = []