Panda-vid committed on
Commit 162d434 · verified · 1 Parent(s): 8cc4283

Adapt modelling for gradient checkpointing


Fixed the parameters passed to the model and removed the old gradient checkpointing method used in T5Stack, as Hugging Face has deprecated it.
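For reference, a minimal usage sketch (not part of this commit): with the override removed, gradient checkpointing is switched on through the stock PreTrainedModel API. The model id below is a placeholder, and the gradient_checkpointing_kwargs argument assumes a recent Transformers release.

# Sketch only: enabling non-reentrant gradient checkpointing on the adapted model.
from transformers import AutoModelForSeq2SeqLM

# "org/model-name" is a placeholder for this repository's model id;
# trust_remote_code=True pulls in the custom modeling_t5.py.
model = AutoModelForSeq2SeqLM.from_pretrained("org/model-name", trust_remote_code=True)

# Recent Transformers versions route this through the generic PreTrainedModel
# machinery, so the removed _set_gradient_checkpointing override is not needed.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()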

Files changed (1)
  1. modeling_t5.py +10 -15
modeling_t5.py CHANGED
@@ -977,10 +977,6 @@ class T5PreTrainedModel(PreTrainedModel):
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (T5Attention, T5Stack)):
-            module.gradient_checkpointing = value
-
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
@@ -1204,14 +1200,8 @@ class T5Stack(T5PreTrainedModel):
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
                 layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                    layer_module,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1221,10 +1211,15 @@ class T5Stack(T5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
-                    relative_position=relative_position,
-                    sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_cache,
+                    output_attentions,
+                    True,  # return_dict is true at training time
+                    relative_position,
+                    sparsity_mask,
+                    use_additional_bucket,
+                    use_reentrant=False
                 )
+
             else:
                 layer_outputs = layer_module(
                     hidden_states,
@@ -1240,7 +1235,7 @@ class T5Stack(T5PreTrainedModel):
                     output_attentions=output_attentions,
                     relative_position=relative_position,
                     sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_additional_bucket=use_additional_bucket
                 )
 
             # layer_outputs is a tuple with:
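For context, a standalone sketch of the calling convention the hunks above adopt: the layer and all of its arguments are handed to torch.utils.checkpoint.checkpoint directly with use_reentrant=False, instead of being wrapped in the deprecated create_custom_forward closure. The tiny module below is illustrative only and is not the T5 block from this file.

import torch
from torch.utils.checkpoint import checkpoint

class TinyLayer(torch.nn.Module):
    """Toy stand-in for a transformer block; illustrative only."""

    def __init__(self, dim=16):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states, scale=1.0):
        return self.linear(hidden_states) * scale

layer = TinyLayer()
x = torch.randn(2, 16, requires_grad=True)

# Old (reentrant) pattern: extra arguments are closed over by a wrapper.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, 2.0)
    return custom_forward

out_old = checkpoint(create_custom_forward(layer), x, use_reentrant=True)

# New (non-reentrant) pattern, as in the diff: pass the module and every
# argument positionally and request use_reentrant=False.
out_new = checkpoint(layer, x, 2.0, use_reentrant=False)

out_new.sum().backward()  # activations are recomputed during the backward pass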