Specify `from_tf=True` to convert a checkpoint from TensorFlow to PyTorch:

```py
pt_model = DistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_tf=True)
pt_model.save_pretrained("path/to/awesome-name-you-picked")
```
</pt>
<tf>
Specify `from_pt=True` to convert a checkpoint from PyTorch to TensorFlow:

```py
tf_model = TFDistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_pt=True)
```
Then you can save your new TensorFlow model with its new checkpoint:

```py
tf_model.save_pretrained("path/to/awesome-name-you-picked")
```
</tf>
If a model is available in Flax, you can also convert a checkpoint from PyTorch to Flax:

```py
flax_model = FlaxDistilBertForSequenceClassification.from_pretrained(
    "path/to/awesome-name-you-picked", from_pt=True
)
```
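Before converting, it can help to check which framework's weights a local checkpoint directory already contains. The following is a minimal sketch, not part of the Transformers API: the helpers `available_frameworks` and `conversion_flag` are hypothetical names, and the sketch assumes the default single-file weight names written by `save_pretrained` (`pytorch_model.bin` or `model.safetensors`, `tf_model.h5`, `flax_model.msgpack`). It ignores sharded checkpoints and treats `model.safetensors` as PyTorch weights, which is a simplification.

```python
from pathlib import Path

# Default weight file names per framework (assumption: unsharded checkpoints).
# model.safetensors is grouped with PyTorch here as a simplifying assumption.
WEIGHT_FILES = {
    "pytorch": ("pytorch_model.bin", "model.safetensors"),
    "tensorflow": ("tf_model.h5",),
    "flax": ("flax_model.msgpack",),
}


def available_frameworks(checkpoint_dir):
    """Return the sorted list of frameworks with a weight file in checkpoint_dir."""
    root = Path(checkpoint_dir)
    return sorted(
        name
        for name, files in WEIGHT_FILES.items()
        if any((root / file_name).is_file() for file_name in files)
    )


def conversion_flag(checkpoint_dir, target):
    """Suggest which from_pretrained() flag is needed to load into `target`.

    Returns None if weights for `target` already exist, "from_pt" or "from_tf"
    if one of the conversions described above applies, and raises ValueError
    otherwise.
    """
    found = available_frameworks(checkpoint_dir)
    if target in found:
        return None  # native weights exist, no conversion needed
    if "pytorch" in found:
        return "from_pt"  # PyTorch -> TensorFlow or Flax
    if "tensorflow" in found and target == "pytorch":
        return "from_tf"  # TensorFlow -> PyTorch
    raise ValueError(f"No convertible weights for {target!r} in {checkpoint_dir}")
```

For example, `conversion_flag("path/to/awesome-name-you-picked", "tensorflow")` would suggest `"from_pt"` when the directory holds only PyTorch weights, which corresponds to passing `from_pt=True` to `from_pretrained`.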
Push a model during training
Sharing a model to the Hub is as simple as adding an extra parameter or callback.