JMalott commited on
Commit
9055def
·
1 Parent(s): 563067a

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -161,6 +161,7 @@ class DalleBartDecoder(nn.Module):
161
  )
162
  decoder_state = self.final_ln(decoder_state)
163
  logits = self.lm_head(decoder_state)
 
164
  temperature = settings[[0]]
165
  top_k = settings[[1]].to(torch.long)
166
  supercondition_factor = settings[[2]]
@@ -171,11 +172,13 @@ class DalleBartDecoder(nn.Module):
171
  )
172
  logits_sorted, _ = logits.sort(descending=True)
173
  is_kept = logits >= logits_sorted[:, top_k - 1]
 
174
  logits -= logits_sorted[:, [0]]
 
175
  logits /= temperature
176
  logits.exp_()
177
  logits *= is_kept.to(torch.float32)
178
- #image_tokens = torch.multinomial(logits, 1)[:, 0]
179
- #del logits
180
 
181
- return torch.multinomial(logits, 1)[:, 0], attention_state
 
161
  )
162
  decoder_state = self.final_ln(decoder_state)
163
  logits = self.lm_head(decoder_state)
164
+ del decorder_state
165
  temperature = settings[[0]]
166
  top_k = settings[[1]].to(torch.long)
167
  supercondition_factor = settings[[2]]
 
172
  )
173
  logits_sorted, _ = logits.sort(descending=True)
174
  is_kept = logits >= logits_sorted[:, top_k - 1]
175
+ del top_k
176
  logits -= logits_sorted[:, [0]]
177
+ del logits_sorted
178
  logits /= temperature
179
  logits.exp_()
180
  logits *= is_kept.to(torch.float32)
181
+ image_tokens = torch.multinomial(logits, 1)[:, 0]
182
+ del logits
183
 
184
+ return image_tokens, attention_state