License: MIT

Convert MUSE from TensorFlow to PyTorch and ONNX

Read more about the project on GitHub.

The PyTorch model can be used not only for inference but also for further training and fine-tuning.
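Below is a minimal sketch of what fine-tuning could look like. It does not use the real `MUSE` class, which requires the project's `src` package. Instead, a hypothetical `StandInEncoder` maps token ids to one 512-dimensional vector per sentence so the example runs on its own. The pair-similarity MSE objective, the toy data, and the file name `model_finetuned.pt` are assumptions for illustration, not the project's training recipe. The weights are saved as a plain `state_dict`, the same kind of object the usage code loads with `load_state_dict`.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class StandInEncoder(nn.Module):
    """Hypothetical stand-in for MUSE: a batch of token ids -> one 512-d vector per sentence."""

    def __init__(self, num_embeddings: int = 1000, dim: int = 512) -> None:
        super().__init__()
        # Mean of the token embeddings in each row, followed by a linear projection.
        self.embedding = nn.EmbeddingBag(num_embeddings, dim, mode="mean")
        self.proj = nn.Linear(dim, dim)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embedding(token_ids))


def train_step(model, optimizer, ids_a, ids_b, target):
    """One optimization step on an assumed pair-similarity objective."""
    model.train()
    optimizer.zero_grad()
    sim = F.cosine_similarity(model(ids_a), model(ids_b), dim=-1)
    loss = F.mse_loss(sim, target)
    loss.backward()
    optimizer.step()
    return loss.item()


model = StandInEncoder()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Toy batch: 4 sentence pairs of 6 token ids each, with similarity targets in [-1, 1].
ids_a = torch.randint(0, 1000, (4, 6))
ids_b = torch.randint(0, 1000, (4, 6))
target = torch.tensor([1.0, 0.0, 0.5, -0.5])

losses = [train_step(model, optimizer, ids_a, ids_b, target) for _ in range(50)]

# Save a plain state_dict and load it into a fresh instance, mirroring the usage code.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_finetuned.pt")
    torch.save(model.state_dict(), path)
    reloaded = StandInEncoder()
    reloaded.load_state_dict(torch.load(path, map_location="cpu"))
```

With the real model, the same loop would presumably apply to `model_torch`, assuming `tokenize(...)` produces inputs the model accepts, as in the usage code.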

Usage

The model is available on the Hugging Face Hub and can be loaded directly with PyTorch (there is currently no native support in the transformers library).

Code to initialize the model and run inference:

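```python
import torch
from functools import partial

# `src` is the package from the project's GitHub repository, not a pip-installable library.
from src.architecture import MUSE
from src.tokenizer import get_tokenizer, tokenize

PATH_TO_PT_MODEL = "model.pt"  # PyTorch state_dict of the converted model
PATH_TO_TF_MODEL = "universal-sentence-encoder-multilingual-large-3"  # original TF Hub checkpoint, used for tokenization

# Build the tokenizer from the TF Hub checkpoint and bind it to `tokenize`.
tokenizer = get_tokenizer(PATH_TO_TF_MODEL)
tokenize = partial(tokenize, tokenizer=tokenizer)

# Instantiate the architecture with the hyperparameters of the converted checkpoint.
model_torch = MUSE(
    num_embeddings=128010,
    embedding_dim=512,
    d_model=512,
    num_heads=8,
)
model_torch.load_state_dict(
    torch.load(PATH_TO_PT_MODEL, map_location="cpu")
)
model_torch.eval()  # inference mode; call model_torch.train() before fine-tuning

sentence = "Hello, world!"
with torch.no_grad():  # no gradients are needed for inference
    res = model_torch(tokenize(sentence))
```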

Tokenization currently relies on the original TF Hub checkpoint, which is why the code above also loads it (`PATH_TO_TF_MODEL`).
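A common use of the resulting sentence embeddings is measuring semantic similarity. The sketch below computes a cosine-similarity matrix for a batch of embeddings. It assumes the per-sentence outputs of `model_torch` have already been combined into a single `(num_sentences, dim)` tensor; the exact output shape is not stated here, so check it before stacking. The function name `pairwise_similarity` and the toy 3-dimensional vectors are illustrative only.

```python
import torch
import torch.nn.functional as F


def pairwise_similarity(embeddings: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every pair of rows in a (num_sentences, dim) tensor."""
    normed = F.normalize(embeddings, p=2, dim=-1)  # scale each row to unit L2 norm
    return normed @ normed.T


# Toy stand-in embeddings; real ones would come from model_torch(tokenize(...)).
emb = torch.tensor([[1.0, 0.0, 0.0],
                    [2.0, 0.0, 0.0],
                    [0.0, 3.0, 0.0]])
sim = pairwise_similarity(emb)
```

Normalizing first makes the result independent of embedding magnitude, so it works whether or not the model's outputs are already unit-length.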