JMalott commited on
Commit
027487b
·
1 Parent(s): 3ff83b9

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -169,11 +169,14 @@ class DalleBartDecoder(nn.Module):
169
  attention_mask,
170
  token_index
171
  )
 
172
  decoder_state = self.final_ln(decoder_state)
173
  logits = self.lm_head(decoder_state)
 
174
  del decoder_state
175
  temperature = settings[[0]]
176
  top_k = settings[[1]].to(torch.long)
 
177
  supercondition_factor = settings[[2]]
178
  logits = logits[:, -1, : 2 ** 14]
179
  logits: FloatTensor = (
 
169
  attention_mask,
170
  token_index
171
  )
172
+ print(tracemalloc.get_traced_memory())
173
  decoder_state = self.final_ln(decoder_state)
174
  logits = self.lm_head(decoder_state)
175
+ print(tracemalloc.get_traced_memory())
176
  del decoder_state
177
  temperature = settings[[0]]
178
  top_k = settings[[1]].to(torch.long)
179
+ print(tracemalloc.get_traced_memory())
180
  supercondition_factor = settings[[2]]
181
  logits = logits[:, -1, : 2 ** 14]
182
  logits: FloatTensor = (