Specify `from_tf=True` to convert a checkpoint from TensorFlow to PyTorch:

```python
from transformers import DistilBertForSequenceClassification
# Load the TensorFlow weights into the PyTorch class, then save them as a PyTorch checkpoint
pt_model = DistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_tf=True)
pt_model.save_pretrained("path/to/awesome-name-you-picked")
```
</pt>
<tf>
Specify `from_pt=True` to convert a checkpoint from PyTorch to TensorFlow:

```python
from transformers import TFDistilBertForSequenceClassification
tf_model = TFDistilBertForSequenceClassification.from_pretrained("path/to/awesome-name-you-picked", from_pt=True)
```

Then you can save your new TensorFlow model with its new checkpoint:

```python
tf_model.save_pretrained("path/to/awesome-name-you-picked")
```
</tf>

If a model is available in Flax, you can also convert a checkpoint from PyTorch to Flax:

```python
from transformers import FlaxDistilBertForSequenceClassification

flax_model = FlaxDistilBertForSequenceClassification.from_pretrained(
    "path/to/awesome-name-you-picked", from_pt=True
)
```

## Push a model during training

Sharing a model to the Hub is as simple as adding an extra parameter or callback.