Spaces:
Runtime error
Runtime error
Update min_dalle/models/dalle_bart_decoder.py
Browse files
min_dalle/models/dalle_bart_decoder.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Tuple, List
|
|
2 |
import torch
|
3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
|
|
5 |
|
6 |
IMAGE_TOKEN_COUNT = 256
|
7 |
|
@@ -100,6 +101,8 @@ class DecoderLayer(nn.Module):
|
|
100 |
decoder_state = self.glu.forward(decoder_state)
|
101 |
decoder_state = residual + decoder_state
|
102 |
|
|
|
|
|
103 |
return decoder_state, attention_state
|
104 |
|
105 |
|
@@ -170,6 +173,7 @@ class DalleBartDecoder(nn.Module):
|
|
170 |
logits[:image_count] * (1 - supercondition_factor) +
|
171 |
logits[image_count:] * supercondition_factor
|
172 |
)
|
|
|
173 |
logits_sorted, _ = logits.sort(descending=True)
|
174 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
175 |
del top_k
|
@@ -179,7 +183,9 @@ class DalleBartDecoder(nn.Module):
|
|
179 |
del temperature
|
180 |
logits.exp_()
|
181 |
logits *= is_kept.to(torch.float32)
|
|
|
182 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
183 |
del logits
|
|
|
184 |
|
185 |
return image_tokens, attention_state
|
|
|
2 |
import torch
|
3 |
from torch import nn, LongTensor, FloatTensor, BoolTensor
|
4 |
from .dalle_bart_encoder import GLU, AttentionBase
|
5 |
+
import gc
|
6 |
|
7 |
IMAGE_TOKEN_COUNT = 256
|
8 |
|
|
|
101 |
decoder_state = self.glu.forward(decoder_state)
|
102 |
decoder_state = residual + decoder_state
|
103 |
|
104 |
+
|
105 |
+
|
106 |
return decoder_state, attention_state
|
107 |
|
108 |
|
|
|
173 |
logits[:image_count] * (1 - supercondition_factor) +
|
174 |
logits[image_count:] * supercondition_factor
|
175 |
)
|
176 |
+
del supercondition_factor
|
177 |
logits_sorted, _ = logits.sort(descending=True)
|
178 |
is_kept = logits >= logits_sorted[:, top_k - 1]
|
179 |
del top_k
|
|
|
183 |
del temperature
|
184 |
logits.exp_()
|
185 |
logits *= is_kept.to(torch.float32)
|
186 |
+
del is_kept
|
187 |
image_tokens = torch.multinomial(logits, 1)[:, 0]
|
188 |
del logits
|
189 |
+
gc.collect()
|
190 |
|
191 |
return image_tokens, attention_state
|