Adapt modelling for gradient checkpointing
Fixed the parameters passed to the checkpointed layer call and removed the old gradient checkpointing method used in T5Stack, as Hugging Face has deprecated it.
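For context, a minimal usage sketch (not part of this commit, assuming a recent `transformers` release): with the deprecated per-model `_set_gradient_checkpointing` override removed, checkpointing is toggled through the base `PreTrainedModel` API instead. The checkpoint name and the `gradient_checkpointing_kwargs` argument below are illustrative and version-dependent.

```python
# Illustrative sketch only: enabling gradient checkpointing via the base-class
# API of a recent transformers release ("t5-small" is just a placeholder).
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Marks the supporting submodules for checkpointing; newer releases also accept
# gradient_checkpointing_kwargs, matching the use_reentrant=False added below.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
model.train()  # checkpointing is only active in training mode
```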
modeling_t5.py (CHANGED: +10, -15)
```diff
@@ -977,10 +977,6 @@ class T5PreTrainedModel(PreTrainedModel):
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (T5Attention, T5Stack)):
-            module.gradient_checkpointing = value
-
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
@@ -1204,14 +1200,8 @@ class T5Stack(T5PreTrainedModel):
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
                 layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                    layer_module,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1221,10 +1211,15 @@ class T5Stack(T5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None, # past_key_value is always None with gradient checkpointing
-
-
-
+                    use_cache,
+                    output_attentions,
+                    True, # return_dict is true at training time
+                    relative_position,
+                    sparsity_mask,
+                    use_additional_bucket,
+                    use_reentrant=False
                 )
+
             else:
                 layer_outputs = layer_module(
                     hidden_states,
@@ -1240,7 +1235,7 @@ class T5Stack(T5PreTrainedModel):
                     output_attentions=output_attentions,
                     relative_position=relative_position,
                     sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket
+                    use_additional_bucket=use_additional_bucket
                 )
 
             # layer_outputs is a tuple with:
```
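The core of the change is in the second and third hunks: instead of wrapping the layer in a `create_custom_forward` closure, the module itself is handed to `torch.utils.checkpoint.checkpoint`, its options (`use_cache`, `output_attentions`, `return_dict`, plus the repo-specific `relative_position`, `sparsity_mask`, `use_additional_bucket`) are passed positionally, and `use_reentrant=False` selects the non-reentrant implementation. A minimal, self-contained sketch of that pattern follows; the `ToyLayer` class and its arguments are invented for illustration and are not taken from `modeling_t5.py`.

```python
# Sketch of the new code path: the layer module goes straight into checkpoint(),
# its arguments are passed positionally, and use_reentrant=False picks the
# non-reentrant implementation. ToyLayer is a stand-in for the real T5 block.
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self, d_model: int = 16):
        super().__init__()
        self.linear = torch.nn.Linear(d_model, d_model)

    def forward(self, hidden_states, attention_mask, output_attentions):
        # Options that would normally be keyword arguments arrive positionally
        # here, mirroring how the diff above calls the block through checkpoint().
        out = self.linear(hidden_states)
        if attention_mask is not None:
            out = out * attention_mask
        return (out, None) if output_attentions else (out,)


layer = ToyLayer()
hidden_states = torch.randn(2, 8, 16, requires_grad=True)

outputs = checkpoint(
    layer,          # the module itself, no closure wrapper
    hidden_states,
    None,           # attention_mask
    False,          # output_attentions
    use_reentrant=False,
)
outputs[0].sum().backward()  # gradients flow through the recomputed forward
```

With `use_reentrant=False`, PyTorch's checkpoint does not require an input tensor with `requires_grad=True` and forwards extra keyword arguments to the wrapped callable, which is what makes the simpler call shape possible without the closure the reentrant variant used.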