mrfakename commited on
Commit
e91a892
·
verified ·
1 Parent(s): a8561bd

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/model/backbones/dit.py +3 -1
src/f5_tts/model/backbones/dit.py CHANGED
@@ -219,7 +219,9 @@ class DiT(nn.Module):
219
 
220
  for block in self.transformer_blocks:
221
  if self.checkpoint_activations:
222
- x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
 
 
223
  else:
224
  x = block(x, t, mask=mask, rope=rope)
225
 
 
219
 
220
  for block in self.transformer_blocks:
221
  if self.checkpoint_activations:
222
+ # if you have question, please check: https://pytorch.org/docs/stable/checkpoint.html#torch.utils.checkpoint.checkpoint
223
+ # After PyTorch 2.4, we must pass the use_reentrant explicitly
224
+ x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
225
  else:
226
  x = block(x, t, mask=mask, rope=rope)
227