Spaces:
Runtime error
Runtime error
Update generate.py
Browse files- generate.py +5 -5
generate.py
CHANGED
@@ -86,10 +86,10 @@ class LmGeneration:
|
|
86 |
total_len = args.seq_length
|
87 |
|
88 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
89 |
-
tokens = torch.full((batch, total_len), self.tokenizer.
|
90 |
for idx, t in enumerate(prompt_tokens):
|
91 |
tokens[idx, : len(t)] = torch.tensor(t).long()
|
92 |
-
mask = tokens != self.tokenizer.
|
93 |
start_pos = min_prompt_len
|
94 |
prev_pos = 0
|
95 |
continue_exsample = [i for i in range(batch)]
|
@@ -118,7 +118,7 @@ class LmGeneration:
|
|
118 |
continue_exsample = []
|
119 |
for i, t in enumerate(tokens.tolist()):
|
120 |
try:
|
121 |
-
t.index(self.tokenizer.
|
122 |
except ValueError:
|
123 |
if cut_off is not None:
|
124 |
if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
|
@@ -134,8 +134,8 @@ class LmGeneration:
|
|
134 |
for i, t in enumerate(tokens.tolist()):
|
135 |
t = t[: args.seq_length]
|
136 |
try:
|
137 |
-
t = t[: t.index(self.tokenizer.
|
138 |
-
t = t[: t.index(self.tokenizer.
|
139 |
except ValueError:
|
140 |
pass
|
141 |
decoder.append(self.tokenizer.decode(t))
|
|
|
86 |
total_len = args.seq_length
|
87 |
|
88 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
89 |
+
tokens = torch.full((batch, total_len), self.tokenizer.pad_token_id).to(device).long()
|
90 |
for idx, t in enumerate(prompt_tokens):
|
91 |
tokens[idx, : len(t)] = torch.tensor(t).long()
|
92 |
+
mask = tokens != self.tokenizer.pad_token_id
|
93 |
start_pos = min_prompt_len
|
94 |
prev_pos = 0
|
95 |
continue_exsample = [i for i in range(batch)]
|
|
|
118 |
continue_exsample = []
|
119 |
for i, t in enumerate(tokens.tolist()):
|
120 |
try:
|
121 |
+
t.index(self.tokenizer.eos_token_id)
|
122 |
except ValueError:
|
123 |
if cut_off is not None:
|
124 |
if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
|
|
|
134 |
for i, t in enumerate(tokens.tolist()):
|
135 |
t = t[: args.seq_length]
|
136 |
try:
|
137 |
+
t = t[: t.index(self.tokenizer.pad_token_id)]
|
138 |
+
t = t[: t.index(self.tokenizer.eos_token_id)]
|
139 |
except ValueError:
|
140 |
pass
|
141 |
decoder.append(self.tokenizer.decode(t))
|