satyanayak committed on
Commit
9c1b483
·
1 Parent(s): c121a67

special token removed

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -32,8 +32,6 @@ model.train(False)
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
  enc = tiktoken.get_encoding('gpt2')
35
- # Modify encoding behavior to allow special tokens
36
- enc._special_tokens.add('<|endoftext|>')
37
  tokens = enc.encode(prompt)
38
  tokens = torch.tensor(tokens, dtype=torch.long)
39
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
@@ -55,10 +53,8 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
55
 
56
  tokens = torch.cat((tokens, next_token), dim=1)
57
 
58
- # Check for end of text token
59
- endoftext_token = enc.encode('<|endoftext|>')[0]
60
- if next_token.item() == endoftext_token:
61
- break
62
 
63
  generated_texts = []
64
  for i in range(num_samples):
 
32
 
33
  def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
34
  enc = tiktoken.get_encoding('gpt2')
 
 
35
  tokens = enc.encode(prompt)
36
  tokens = torch.tensor(tokens, dtype=torch.long)
37
  tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
 
53
 
54
  tokens = torch.cat((tokens, next_token), dim=1)
55
 
56
+ # Remove special token check entirely
57
+ # Just generate for the specified length or until context limit
 
 
58
 
59
  generated_texts = []
60
  for i in range(num_samples):