# Kaizouku's picture
# Upload 564 files
# 2260825 verified
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs
if is_torch_available():
    from transformers import AutoModel, PreTrainedModel

    # Maps each feature name accepted by the exporter to the auto-class whose
    # `from_pretrained` is used to load the model. Only defined when torch is available.
    FEATURES_TO_AUTOMODELS = {"default": AutoModel}
# Supported model topologies. Each model type maps to the features it supports, and each feature maps to the
# factory that is called with the model's config to build the matching ONNX configuration.
SUPPORTED_MODEL_KIND = {
    "albert": {"default": AlbertOnnxConfig.default},
    "bart": {"default": BartOnnxConfig.default},
    "bert": {"default": BertOnnxConfig.default},
    "distilbert": {"default": DistilBertOnnxConfig.default},
    "gpt2": {"default": GPT2OnnxConfig.default},
    "longformer": {"default": LongformerOnnxConfig.default},
    # Registered through the `default` factory like every other entry.
    # NOTE(review): assumes RobertaOnnxConfig exposes the same `default` factory as its siblings - confirm.
    "roberta": {"default": RobertaOnnxConfig.default},
    "t5": {"default": T5OnnxConfig.default},
    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
def get_model_from_features(features: str, model: str):
    """
    Attempt to retrieve a model from a model's name and the features to be enabled.

    Args:
        features: The features required
        model: The name of the model to export

    Returns:
        The object returned by `from_pretrained(model)` on the auto-class registered for `features` in
        `FEATURES_TO_AUTOMODELS`.

    Raises:
        KeyError: If `features` is not a key of `FEATURES_TO_AUTOMODELS`.
    """
    if features not in FEATURES_TO_AUTOMODELS:
        # The valid choices are the feature names (the keys), not the auto-classes they map to.
        raise KeyError(f"Unknown feature: {features}. Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}")
    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
def check_supported_model_or_raise(model: "PreTrainedModel", features: str = "default") -> Tuple[str, Callable]:
    """
    Check whether or not the model has the requested features

    Args:
        model: The model to export
        features: The name of the features to check if they are available

    Returns:
        (str) The type of the model (Callable) The factory registered in `SUPPORTED_MODEL_KIND` for this model type
        and these features

    Raises:
        KeyError: If the model type is not a key of `SUPPORTED_MODEL_KIND`.
        ValueError: If the model type is supported but does not provide the requested features.
    """
    # The annotation above is a string because `PreTrainedModel` is only imported when torch is available.
    model_type = model.config.model_type

    if model_type not in SUPPORTED_MODEL_KIND:
        # Read `name` defensively so that a model without that attribute still produces the intended KeyError
        # instead of an AttributeError raised while building the message.
        model_name = getattr(model, "name", "")
        raise KeyError(
            f"{model_type} ({model_name}) is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support ({model_type}) please propose a PR or open up an issue."
        )

    # Look for the features
    model_features = SUPPORTED_MODEL_KIND[model_type]
    if features not in model_features:
        raise ValueError(
            f"{model_type} doesn't support features {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model_type, model_features[features]
def main():
    """
    Command line entry point: export a pretrained model to ONNX and validate the exported file.

    The steps are: parse the CLI arguments, resolve the output path, load the tokenizer and the model, check that
    the model type supports the requested features, check the requested opset, export, then validate the outputs
    of the exported model against the given absolute tolerance.

    Raises:
        ValueError: If the requested opset is lower than the `default_onnx_opset` of the model's ONNX config.
    """
    # Fetch the logger locally rather than relying on the global that only the `__main__` guard binds, so that
    # main() can also be called from importing code without a NameError at the final log statement.
    logger = logging.get_logger("transformers.onnx")

    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name or path on disk to load.")
    parser.add_argument(
        "--features",
        choices=["default"],
        default="default",
        help="Export the model with some additional features.",
    )
    parser.add_argument(
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute difference tolerance when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()

    # An existing file is used as-is; any other path is treated as a directory that will hold `model.onnx`.
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Allocate the model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_features(args.features, args.model)
    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient
    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.default_onnx_opset} is required."
        )

    # Only the output names are needed for the validation step.
    _, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")
if __name__ == "__main__":
    # Script entry point: configure the `transformers.onnx` logger at INFO level, then run the exporter.
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)
    main()