Spaces:
Build error
Build error
Commit
·
c121a67
1
Parent(s):
e243b3e
<|endoftext|> token handled
Browse files
app.py
CHANGED
@@ -31,8 +31,9 @@ model = load_model_from_hf()
|
|
31 |
model.train(False)
|
32 |
|
33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
34 |
-
|
35 |
-
|
|
|
36 |
tokens = enc.encode(prompt)
|
37 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
38 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
@@ -55,7 +56,8 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
|
55 |
tokens = torch.cat((tokens, next_token), dim=1)
|
56 |
|
57 |
# Check for end of text token
|
58 |
-
|
|
|
59 |
break
|
60 |
|
61 |
generated_texts = []
|
|
|
31 |
model.train(False)
|
32 |
|
33 |
def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
|
34 |
+
enc = tiktoken.get_encoding('gpt2')
|
35 |
+
# Modify encoding behavior to allow special tokens
|
36 |
+
enc._special_tokens.add('<|endoftext|>')
|
37 |
tokens = enc.encode(prompt)
|
38 |
tokens = torch.tensor(tokens, dtype=torch.long)
|
39 |
tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
|
|
|
56 |
tokens = torch.cat((tokens, next_token), dim=1)
|
57 |
|
58 |
# Check for end of text token
|
59 |
+
endoftext_token = enc.encode('<|endoftext|>')[0]
|
60 |
+
if next_token.item() == endoftext_token:
|
61 |
break
|
62 |
|
63 |
generated_texts = []
|