bugfix: Update modeling_t5.T5Stack.forward() for Gradient Checkpointing
Update the checkpoint() call so that the parameters for the layer_module object are passed correctly. The old create_custom_forward wrapper appended use_cache and output_attentions after *inputs, so the trailing relative_position, sparsity_mask, and use_additional_bucket arguments were shifted into the wrong parameter slots of the layer's forward().
modeling_t5.py (CHANGED, +10 -11)
@@ -1204,14 +1204,8 @@ class T5Stack(T5PreTrainedModel):
 
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-
-                    return custom_forward
-
                 layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                    layer_module,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1221,10 +1215,15 @@ class T5Stack(T5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
-                    relative_position,
-                    sparsity_mask,
-                    use_additional_bucket
+                    use_cache,
+                    output_attentions,
+                    True,  # return_dict is true at training time
+                    relative_position,
+                    sparsity_mask,
+                    use_additional_bucket,
+                    use_reentrant=False
                 )
+
             else:
                 layer_outputs = layer_module(
                     hidden_states,
@@ -1240,7 +1239,7 @@ class T5Stack(T5PreTrainedModel):
                     output_attentions=output_attentions,
                     relative_position=relative_position,
                     sparsity_mask=sparsity_mask,
-                    use_additional_bucket=use_additional_bucket
+                    use_additional_bucket=use_additional_bucket,
                 )
 
                 # layer_outputs is a tuple with: