JMalott commited on
Commit
ca509e8
·
1 Parent(s): 602c80d

Update min_dalle/models/dalle_bart_decoder.py

Browse files
min_dalle/models/dalle_bart_decoder.py CHANGED
@@ -2,6 +2,7 @@ from typing import Tuple, List
2
  import torch
3
  from torch import nn, LongTensor, FloatTensor, BoolTensor
4
  from .dalle_bart_encoder import GLU, AttentionBase
 
5
 
6
  IMAGE_TOKEN_COUNT = 256
7
 
@@ -100,6 +101,8 @@ class DecoderLayer(nn.Module):
100
  decoder_state = self.glu.forward(decoder_state)
101
  decoder_state = residual + decoder_state
102
 
 
 
103
  return decoder_state, attention_state
104
 
105
 
@@ -170,6 +173,7 @@ class DalleBartDecoder(nn.Module):
170
  logits[:image_count] * (1 - supercondition_factor) +
171
  logits[image_count:] * supercondition_factor
172
  )
 
173
  logits_sorted, _ = logits.sort(descending=True)
174
  is_kept = logits >= logits_sorted[:, top_k - 1]
175
  del top_k
@@ -179,7 +183,9 @@ class DalleBartDecoder(nn.Module):
179
  del temperature
180
  logits.exp_()
181
  logits *= is_kept.to(torch.float32)
 
182
  image_tokens = torch.multinomial(logits, 1)[:, 0]
183
  del logits
 
184
 
185
  return image_tokens, attention_state
 
2
  import torch
3
  from torch import nn, LongTensor, FloatTensor, BoolTensor
4
  from .dalle_bart_encoder import GLU, AttentionBase
5
+ import gc
6
 
7
  IMAGE_TOKEN_COUNT = 256
8
 
 
101
  decoder_state = self.glu.forward(decoder_state)
102
  decoder_state = residual + decoder_state
103
 
104
+
105
+
106
  return decoder_state, attention_state
107
 
108
 
 
173
  logits[:image_count] * (1 - supercondition_factor) +
174
  logits[image_count:] * supercondition_factor
175
  )
176
+ del supercondition_factor
177
  logits_sorted, _ = logits.sort(descending=True)
178
  is_kept = logits >= logits_sorted[:, top_k - 1]
179
  del top_k
 
183
  del temperature
184
  logits.exp_()
185
  logits *= is_kept.to(torch.float32)
186
+ del is_kept
187
  image_tokens = torch.multinomial(logits, 1)[:, 0]
188
  del logits
189
+ gc.collect()
190
 
191
  return image_tokens, attention_state