587b6c9 04a69d4 587b6c9 04a69d4 587b6c9
1
2
3
4
5
6
7
8
9
# Strip the optimizer state from a pickled checkpoint and rewrite the
# checkpoint at the same path with every array fetched back to the host.
import os
import pickle

import jax

CKPT_PATH = "./mono_tts_cbhg_small_0700000.ckpt"

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only run this on checkpoints that come from a trusted source.
with open(CKPT_PATH, "rb") as src:
    dic = pickle.load(src)

# Drop the optimizer state. A missing key raises KeyError before anything is
# written, so an already-stripped checkpoint is left untouched on disk.
del dic["optim_state_dict"]

# Fetch the remaining values to the host before re-pickling them.
dic = jax.device_get(dic)

# Write to a sibling temp file first and swap it in afterwards: a failure in
# the middle of pickle.dump must not leave a truncated checkpoint behind.
tmp_path = CKPT_PATH + ".tmp"
with open(tmp_path, "wb") as dst:
    pickle.dump(dic, dst)
os.replace(tmp_path, CKPT_PATH)