Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
@@ -161,6 +161,7 @@ class DalleBartDecoder(nn.Module):
|
|
161 |
)
|
162 |
decoder_state = self.final_ln(decoder_state)
|
163 |
logits = self.lm_head(decoder_state)
|
|
|
164 |
temperature = settings[[0]]
|
165 |
top_k = settings[[1]].to(torch.long)
|
166 |
supercondition_factor = settings[[2]]
|
@@ -171,11 +172,13 @@ class DalleBartDecoder(nn.Module):
|
|
171 |
)
|
172 |
logits_sorted, _ = logits.sort(descending=True)
|
173 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
|
|
174 |
logits -= logits_sorted[:, [0]]
|
|
|
175 |
logits /= temperature
|
176 |
logits.exp_()
|
177 |
logits *= is_kept.to(torch.float32)
|
178 |
-
|
179 |
-
|
180 |
|
181 |
-
return
|
|
|
161 |
)
|
162 |
decoder_state = self.final_ln(decoder_state)
|
163 |
logits = self.lm_head(decoder_state)
|
164 |
+
del decorder_state
|
165 |
temperature = settings[[0]]
|
166 |
top_k = settings[[1]].to(torch.long)
|
167 |
supercondition_factor = settings[[2]]
|
|
|
172 |
)
|
173 |
logits_sorted, _ = logits.sort(descending=True)
|
174 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
175 |
+
del top_k
|
176 |
logits -= logits_sorted[:, [0]]
|
177 |
+
del logits_sorted
|
178 |
logits /= temperature
|
179 |
logits.exp_()
|
180 |
logits *= is_kept.to(torch.float32)
|
181 |
+
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
182 |
+
del logits
|
183 |
|
184 |
+
return image_tokens, attention_state
|