Commit 9c1b483
1 Parent(s): c121a67

special token removed
app.py CHANGED
@@ -32,8 +32,6 @@ model.train(False)
 
 def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
     enc = tiktoken.get_encoding('gpt2')
-    # Modify encoding behavior to allow special tokens
-    enc._special_tokens.add('<|endoftext|>')
     tokens = enc.encode(prompt)
     tokens = torch.tensor(tokens, dtype=torch.long)
     tokens = tokens.unsqueeze(0).repeat(num_samples, 1)
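The removed lines worked by mutating tiktoken's private _special_tokens set. For reference, a minimal sketch of the supported alternative, which passes the allowed_special parameter to encode instead (the prompt string here is illustrative):

import tiktoken

enc = tiktoken.get_encoding('gpt2')

# Whitelist the special token at encode time instead of mutating
# the private enc._special_tokens set, as the removed lines did.
tokens = enc.encode('Hello<|endoftext|>', allowed_special={'<|endoftext|>'})

# The end-of-text id is also exposed directly; for GPT-2 it is 50256.
endoftext_token = enc.eot_token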
@@ -55,10 +53,8 @@ def generate_text(prompt, max_length=100, num_samples=1, temperature=0.8):
 
         tokens = torch.cat((tokens, next_token), dim=1)
 
-        #
-
-        if next_token.item() == endoftext_token:
-            break
+        # Remove special token check entirely
+        # Just generate for the specified length or until context limit
 
     generated_texts = []
     for i in range(num_samples):
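For context, the early stop deleted in the second hunk sat inside the sampling loop. A minimal sketch of that loop follows; the model call, sampling details, and function shape are assumptions for illustration, not the Space's actual code:

import torch

@torch.no_grad()
def sample(model, tokens, max_length, temperature, endoftext_token):
    # Sketch only: assumes model(tokens) returns logits of shape
    # (batch, seq_len, vocab_size).
    while tokens.size(1) < max_length:
        logits = model(tokens)[:, -1, :] / temperature
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat((tokens, next_token), dim=1)
        # This is the check the commit removes; without it, generation
        # always runs to max_length (or the model's context limit).
        # Note: .item() assumes a single row, i.e. num_samples == 1.
        if next_token.item() == endoftext_token:
            break
    return tokens

With the check gone, generation no longer halts when the model emits <|endoftext|>, matching the new comments: the loop simply runs for the requested length.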