Panda-vid committed
Commit 0d0e83a · verified · 1 parent: 8cc4283

bugfix: Update modeling_t5.T5Stack.forward() for Gradient Checkpointing


Update the checkpoint() call so that the arguments for the layer_module object are passed correctly.
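For context, a minimal sketch of the pattern this commit moves to (a toy module, not the repo's T5Block): with torch.utils.checkpoint.checkpoint(..., use_reentrant=False), the layer can be checkpointed directly and its extra arguments forwarded positionally, so the old create_custom_forward closure is unnecessary.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a transformer block that takes non-tensor extras."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states, scale=1.0, output_attentions=False):
        out = self.linear(hidden_states) * scale
        # Mimic the tuple-style return of Hugging Face blocks.
        return (out, None) if output_attentions else (out,)


block = ToyBlock()
x = torch.randn(2, 8, requires_grad=True)

# The old reentrant style needed a closure to bind scale/output_attentions;
# the non-reentrant checkpoint forwards the extra positional args directly.
layer_outputs = checkpoint(block, x, 2.0, False, use_reentrant=False)
layer_outputs[0].sum().backward()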

Files changed (1): modeling_t5.py (+10 −11)
modeling_t5.py CHANGED
@@ -1204,14 +1204,8 @@ class T5Stack(T5PreTrainedModel):
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
                 layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                    layer_module,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1221,10 +1215,15 @@ class T5Stack(T5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
-                    relative_position=relative_position,
-                    sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_cache,
+                    output_attentions,
+                    True,  # return_dict is true at training time
+                    relative_position,
+                    sparsity_mask,
+                    use_additional_bucket,
+                    use_reentrant=False
                 )
+
             else:
                 layer_outputs = layer_module(
                     hidden_states,
@@ -1240,7 +1239,7 @@ class T5Stack(T5PreTrainedModel):
                     output_attentions=output_attentions,
                     relative_position=relative_position,
                     sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket,
+                    use_additional_bucket=use_additional_bucket
                 )
 
                 # layer_outputs is a tuple with:
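A design note on the new call, hedged: the non-reentrant checkpoint also accepts keyword arguments, so passing relative_position, sparsity_mask, and use_additional_bucket positionally is a choice that mirrors the layer's parameter order rather than a requirement. A self-contained sketch with a toy layer (the parameter name is borrowed from the diff for illustration):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, hidden_states, sparsity_mask=None):
        out = self.proj(hidden_states)
        if sparsity_mask is not None:
            out = out * sparsity_mask
        return out


layer = TinyLayer()
x = torch.randn(3, 4, requires_grad=True)
mask = torch.ones(3, 4)

# Keyword arguments reach the wrapped module only with use_reentrant=False;
# the legacy reentrant implementation rejects them.
out = checkpoint(layer, x, sparsity_mask=mask, use_reentrant=False)
out.sum().backward()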