Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
@@ -169,11 +169,14 @@ class DalleBartDecoder(nn.Module):
|
|
169 |
attention_mask,
|
170 |
token_index
|
171 |
)
|
|
|
172 |
decoder_state = self.final_ln(decoder_state)
|
173 |
logits = self.lm_head(decoder_state)
|
|
|
174 |
del decoder_state
|
175 |
temperature = settings[[0]]
|
176 |
top_k = settings[[1]].to(torch.long)
|
|
|
177 |
supercondition_factor = settings[[2]]
|
178 |
logits = logits[:, -1, : 2 ** 14]
|
179 |
logits: FloatTensor = (
|
|
|
169 |
attention_mask,
|
170 |
token_index
|
171 |
)
|
172 |
+
print(tracemalloc.get_traced_memory())
|
173 |
decoder_state = self.final_ln(decoder_state)
|
174 |
logits = self.lm_head(decoder_state)
|
175 |
+
print(tracemalloc.get_traced_memory())
|
176 |
del decoder_state
|
177 |
temperature = settings[[0]]
|
178 |
top_k = settings[[1]].to(torch.long)
|
179 |
+
print(tracemalloc.get_traced_memory())
|
180 |
supercondition_factor = settings[[2]]
|
181 |
logits = logits[:, -1, : 2 ** 14]
|
182 |
logits: FloatTensor = (
|