# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
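"""
Command-line tool to export a transformers model to ONNX. Usage sketch (assuming this
module lives at ``transformers/onnx/__main__.py`` and is run as a script):

    python -m transformers.onnx --model=bert-base-cased --opset=12 onnx/bert-base-cased/
"""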
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple

from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs

if is_torch_available():
    from transformers import AutoModel, PreTrainedModel

    # Maps each exportable feature to the auto-model class able to load it
    FEATURES_TO_AUTOMODELS = {
        "default": AutoModel,
    }
# Mapping of the model topologies we support to the features supported by each topology
# and the factory producing the corresponding OnnxConfig
SUPPORTED_MODEL_KIND = {
    "albert": {"default": AlbertOnnxConfig.default},
    "bart": {"default": BartOnnxConfig.default},
    "bert": {"default": BertOnnxConfig.default},
    "distilbert": {"default": DistilBertOnnxConfig.default},
    "gpt2": {"default": GPT2OnnxConfig.default},
    "longformer": {"default": LongformerOnnxConfig.default},
    "roberta": {"default": RobertaOnnxConfig.default},
    "t5": {"default": T5OnnxConfig.default},
    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}
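
# To add support for a new topology, register its OnnxConfig factory here; hypothetical
# example (ElectraOnnxConfig is illustrative, not one of the imports above):
#   SUPPORTED_MODEL_KIND["electra"] = {"default": ElectraOnnxConfig.default}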


def get_model_from_features(features: str, model: str):
    """
    Attempt to retrieve a model from a model's name and the features to be enabled.

    Args:
        features: The features required
        model: The name of the model to export

    Returns:
        The instance of the model to export
    """
    if features not in FEATURES_TO_AUTOMODELS:
        raise KeyError(
            f"Unknown feature: {features}. Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}"
        )

    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
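
# Usage sketch (the checkpoint name is illustrative):
#   model = get_model_from_features("default", "bert-base-cased")  # -> AutoModel.from_pretrained(...)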


def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
    """
    Check whether or not the model has the requested features.

    Args:
        model: The model to export
        features: The name of the features to check if they are available

    Returns:
        (str) The type of the model
        (Callable) The OnnxConfig factory holding the model export properties
    """
    if model.config.model_type not in SUPPORTED_MODEL_KIND:
        raise KeyError(
            f"{model.config.model_type} ({model.config.name_or_path}) is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support {model.config.model_type} please propose a PR or open up an issue."
        )

    # Look for the requested features
    model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
    if features not in model_features:
        raise ValueError(
            f"{model.config.model_type} doesn't support feature {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
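
# Usage sketch: the returned factory builds the export config from the model's config:
#   model_kind, onnx_config_factory = check_supported_model_or_raise(model)  # e.g. ("bert", BertOnnxConfig.default)
#   onnx_config = onnx_config_factory(model.config)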


def main():
    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name or path on disk to load.")
    parser.add_argument(
        "--features",
        choices=["default"],
        default="default",
        help="Export the model with some additional features.",
    )
    parser.add_argument(
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute difference tolerance when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store the generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")

    if not args.output.parent.exists():
        args.output.parent.mkdir(parents=True)

    # Allocate the model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_features(args.features, args.model)
    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient
    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.default_onnx_opset} is required."
        )

    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")


if __name__ == "__main__":
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)
    main()