Update model.py
Browse files
model.py
CHANGED
|
@@ -381,7 +381,11 @@ class StripedHyena(nn.Module):
|
|
| 381 |
x = x * padding_mask[..., None]
|
| 382 |
|
| 383 |
for _, block in enumerate(self.blocks):
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
return x, None
|
| 386 |
|
| 387 |
def initialize_inference_params(self):
|
|
|
|
| 381 |
x = x * padding_mask[..., None]
|
| 382 |
|
| 383 |
for _, block in enumerate(self.blocks):
|
| 384 |
+
if self.gradient_checkpointing and self.training:
|
| 385 |
+
x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
|
| 386 |
+
else:
|
| 387 |
+
x, _ = block(x, inference_params=None, padding_mask=padding_mask)
|
| 388 |
+
|
| 389 |
return x, None
|
| 390 |
|
| 391 |
def initialize_inference_params(self):
|