# NOTE(review): removed stray page-scrape artifacts ("Spaces:", "Sleeping", "Sleeping")
# that preceded the license header and made this module un-importable.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple

from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs
# PyTorch is an optional dependency: only import the AutoModel factory
# (and the PreTrainedModel type used in annotations below) when it is installed.
if is_torch_available():
    from transformers import AutoModel, PreTrainedModel

# Maps each exportable "feature" name to the AutoModel class used to load the model.
# Currently only the base ("default") topology is supported.
FEATURES_TO_AUTOMODELS = {
    "default": AutoModel,
}
# Set of model topologies we support, associated to the features supported by each
# topology and the factory producing the corresponding OnnxConfig.
# Every value is the `.default` classmethod factory (callable taking the model config);
# the "roberta" entry previously stored the bare class, which is inconsistent with the
# other entries and would break `model_onnx_config(model.config)` in `main`.
SUPPORTED_MODEL_KIND = {
    "albert": {"default": AlbertOnnxConfig.default},
    "bart": {"default": BartOnnxConfig.default},
    "bert": {"default": BertOnnxConfig.default},
    "distilbert": {"default": DistilBertOnnxConfig.default},
    "gpt2": {"default": GPT2OnnxConfig.default},
    "longformer": {"default": LongformerOnnxConfig.default},
    "roberta": {"default": RobertaOnnxConfig.default},
    "t5": {"default": T5OnnxConfig.default},
    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
def get_model_from_features(features: str, model: str):
    """
    Attempt to retrieve a model from a model's name and the features to be enabled.

    Args:
        features: The name of the features required (e.g. "default")
        model: The name of (or path to) the model to load

    Returns:
        The model instance loaded through the AutoModel class matching `features`.

    Raises:
        KeyError: If `features` is not a supported feature name.
    """
    if features not in FEATURES_TO_AUTOMODELS:
        # List the valid feature *names* (keys); the original listed the AutoModel
        # classes (values), which is useless to the user. Also add the missing space
        # the adjacent f-strings used to swallow ("...{features}.Possible...").
        raise KeyError(f"Unknown feature: {features}. Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}")
    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
    """
    Check whether or not the model has the requested features.

    Args:
        model: The model to export
        features: The name of the features to check if they are available

    Returns:
        (str) The type of the model
        (Callable) The OnnxConfig factory holding the model export properties

    Raises:
        KeyError: If the model topology is not supported at all.
        ValueError: If the topology is supported but not with the requested features.
    """
    model_type = model.config.model_type
    if model_type not in SUPPORTED_MODEL_KIND:
        # NOTE(review): the original interpolated `model.name`, an attribute
        # PreTrainedModel does not define, so building this message would itself
        # raise AttributeError. Also list only the supported keys rather than
        # dumping the whole mapping (including factory objects).
        raise KeyError(
            f"{model_type} is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support ({model_type}) please propose a PR or open up an issue."
        )

    # Look for the requested features among those supported by this topology
    model_features = SUPPORTED_MODEL_KIND[model_type]
    if features not in model_features:
        raise ValueError(
            f"{model_type} doesn't support features {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model_type, model_features[features]
def main():
    """Parse CLI arguments, export the requested model to ONNX and validate the result."""
    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
    parser.add_argument(
        "--features",
        choices=["default"],
        default="default",
        help="Export the model with some additional features.",
    )
    parser.add_argument(
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()

    # Resolve the logger here: the original referenced a module-global `logger`
    # created only under the `__main__` guard, so importing this module and
    # calling main() directly raised NameError.
    logger = logging.get_logger("transformers.onnx")

    # If a directory was given, store the model inside it under a default name.
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")
    if not args.output.parent.exists():
        args.output.parent.mkdir(parents=True)

    # Allocate the model and its tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_features(args.features, args.model)
    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient for this topology
    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.default_onnx_opset} is required."
        )

    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    # Compare ONNX Runtime outputs against the reference model within --atol
    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")
if __name__ == "__main__":
    # Module-level logger for script usage; `main` resolves its own logger so it
    # also works when imported.
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)
    main()