Adapt modelling for gradient checkpointing

#3
Files changed (1)
  1. modeling_t5.py +10 -15
modeling_t5.py CHANGED
@@ -977,10 +977,6 @@ class T5PreTrainedModel(PreTrainedModel):
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (T5Attention, T5Stack)):
-            module.gradient_checkpointing = value
-
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
@@ -1204,14 +1200,8 @@ class T5Stack(T5PreTrainedModel):
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
                 layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                    layer_module,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1221,10 +1211,15 @@ class T5Stack(T5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
-                    relative_position=relative_position,
-                    sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_cache,
+                    output_attentions,
+                    True,  # return_dict is true at training time
+                    relative_position,
+                    sparsity_mask,
+                    use_additional_bucket,
+                    use_reentrant=False
                 )
+
             else:
                 layer_outputs = layer_module(
                     hidden_states,
@@ -1240,7 +1235,7 @@ class T5Stack(T5PreTrainedModel):
                     output_attentions=output_attentions,
                     relative_position=relative_position,
                     sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_additional_bucket=use_additional_bucket
                 )
 
             # layer_outputs is a tuple with:
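The core change hands the T5 block to torch.utils.checkpoint.checkpoint directly and switches to the non-reentrant backend (use_reentrant=False), which forwards non-tensor and extra positional arguments as-is, so the create_custom_forward closure is no longer needed and use_cache, output_attentions, return_dict, relative_position, sparsity_mask, and use_additional_bucket can be passed straight through. Below is a minimal, self-contained sketch of the same calling pattern; the Block module, its signature, and the tensor shapes are illustrative stand-ins, not part of this PR:

# Minimal sketch of the non-reentrant checkpointing pattern used in this PR.
# Block and the shapes below are illustrative assumptions, not PR code.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, d_model=64):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        hidden_states = torch.relu(self.linear(hidden_states))
        attn = hidden_states if output_attentions else None
        return (hidden_states, attn)


block = Block().train()
x = torch.randn(2, 8, 64, requires_grad=True)

# The old reentrant backend needed a create_custom_forward closure because it
# could not forward keyword arguments; with use_reentrant=False the module is
# passed directly and non-tensor positional arguments (None, booleans) pass through.
hidden, attn = checkpoint(block, x, None, False, use_reentrant=False)
hidden.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8, 64]): gradients flow through the checkpointed block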
 
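Dropping the _set_gradient_checkpointing override fits the same migration: recent transformers releases toggle the gradient_checkpointing flag from the PreTrainedModel base class, so the old per-model hook with the (module, value) signature appears to be unnecessary here. Callers keep enabling it through the public API; a hedged usage sketch, with the stock t5-small checkpoint as a stand-in for the adapted model:

# Hedged usage sketch: enabling checkpointing via the standard transformers API.
# "t5-small" is a stand-in checkpoint, not the model this PR actually targets.
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.gradient_checkpointing_enable()  # flips gradient_checkpointing on the T5Stack modules
model.train()  # the checkpoint() branch above only runs when self.training is True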