Markus28 commited on
Commit
75d7a16
·
1 Parent(s): 5b58f09

feat: support gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_bert.py +16 -0
modeling_bert.py CHANGED
@@ -154,6 +154,17 @@ class BertEncoder(nn.Module):
154
  self.layers = nn.ModuleList(
155
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
156
  )
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
159
  """If subset_mask is not None, we only want output for the subset of the sequence.
@@ -298,6 +309,11 @@ class BertPreTrainedModel(PreTrainedModel):
298
  """
299
  config_class = JinaBertConfig
300
  base_model_prefix = "bert"
 
 
 
 
 
301
 
302
 
303
  class BertModel(BertPreTrainedModel):
 
154
  self.layers = nn.ModuleList(
155
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
156
  )
157
+ self._grad_checkpointing = False
158
+
159
+ @property
160
+ def gradient_checkpointing(self):
161
+ return self._grad_checkpointing
162
+
163
+ @gradient_checkpointing.setter
164
+ def gradient_checkpointing(self, value):
165
+ self._grad_checkpointing = value
166
+ for block in self.layers:
167
+ block.mixer.checkpointing = value
168
 
169
  def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
170
  """If subset_mask is not None, we only want output for the subset of the sequence.
 
309
  """
310
  config_class = JinaBertConfig
311
  base_model_prefix = "bert"
312
+ supports_gradient_checkpointing = True
313
+
314
+ def _set_gradient_checkpointing(self, module, value=False):
315
+ if isinstance(module, BertEncoder):
316
+ module.gradient_checkpointing = value
317
 
318
 
319
  class BertModel(BertPreTrainedModel):