File size: 5,822 Bytes
2260825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser
from pathlib import Path
from typing import Callable, Tuple

from transformers.models.albert import AlbertOnnxConfig
from transformers.models.auto import AutoTokenizer
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert import BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig

from .. import is_torch_available
from ..utils import logging
from .convert import export, validate_model_outputs


if is_torch_available():
    from transformers import AutoModel, PreTrainedModel

    # Map each ``--features`` value to the auto-model class used to load the checkpoint.
    # NOTE: only defined when torch is available.
    FEATURES_TO_AUTOMODELS = {
        "default": AutoModel,
    }


# Set of model topologies (``config.model_type`` values) we support, each associated with the features it
# supports and, per feature, the factory that builds the OnnxConfig from the model config.
SUPPORTED_MODEL_KIND = {
    "albert": {"default": AlbertOnnxConfig.default},
    "bart": {"default": BartOnnxConfig.default},
    "bert": {"default": BertOnnxConfig.default},
    "distilbert": {"default": DistilBertOnnxConfig.default},
    "gpt2": {"default": GPT2OnnxConfig.default},
    "longformer": {"default": LongformerOnnxConfig.default},
    # Use the ``default`` factory like every other entry instead of the bare class.
    "roberta": {"default": RobertaOnnxConfig.default},
    "t5": {"default": T5OnnxConfig.default},
    "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
}


def get_model_from_features(features: str, model: str):
    """
    Attempt to retrieve a model from a model's name and the features to be enabled.

    Args:
        features: The features required; must be one of the keys of ``FEATURES_TO_AUTOMODELS``.
        model: The name (or path on disk) of the model to export, forwarded to ``from_pretrained``.

    Returns:
        The model loaded by ``from_pretrained`` of the auto-model class registered for ``features``.

    Raises:
        KeyError: If ``features`` is not a known feature.
    """
    if features not in FEATURES_TO_AUTOMODELS:
        # Report the feature *names* (dict keys), not the auto-model classes they map to.
        raise KeyError(
            f"Unknown feature: {features}. " f"Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}"
        )

    return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)


def check_supported_model_or_raise(model: "PreTrainedModel", features: str = "default") -> Tuple[str, Callable]:
    """
    Check whether or not the model has the requested features.

    Args:
        model: The model to export.
        features: The name of the features to check if they are available.

    Returns:
        (str) The type of the model (``model.config.model_type``) and (Callable) the factory registered in
        ``SUPPORTED_MODEL_KIND`` for these features; calling it with the model config builds the OnnxConfig
        instance holding the model export properties.

    Raises:
        KeyError: If the model type is not listed in ``SUPPORTED_MODEL_KIND``.
        ValueError: If the model type does not support the requested features.
    """
    # The annotation is a string because ``PreTrainedModel`` is only imported when torch is available;
    # evaluating it eagerly would make the module fail to import without torch.
    model_type = model.config.model_type
    if model_type not in SUPPORTED_MODEL_KIND:
        # Models do not expose a ``name`` attribute; use ``name_or_path`` when present and fall back to the
        # class name so that building the error message cannot itself raise AttributeError.
        model_name = getattr(model, "name_or_path", type(model).__name__)
        raise KeyError(
            f"{model_type} ({model_name}) is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support ({model_type}) please propose a PR or open up an issue."
        )

    # Look for the features
    model_features = SUPPORTED_MODEL_KIND[model_type]
    if features not in model_features:
        raise ValueError(
            f"{model_type} doesn't support features {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model_type, model_features[features]


def main():
    """
    Command-line entry point of the Hugging Face ONNX exporter.

    Parses the CLI arguments, loads the tokenizer and the model, checks that the model type supports the
    requested features and that the requested opset is sufficient, exports the model to ONNX, then runs
    ``validate_model_outputs`` on the exported file with tolerance ``--atol``.

    Raises:
        ValueError: If ``--opset`` is lower than the ONNX config's ``default_onnx_opset``.
    """
    # Resolve the logger locally instead of relying on a global that is only bound under ``__main__``; this
    # keeps ``main`` usable when it is imported and called from elsewhere.
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name

    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name or path on disk to load.")
    parser.add_argument(
        "--features",
        choices=["default"],
        default="default",
        help="Export the model with some additional features.",
    )
    parser.add_argument(
        "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
    )
    parser.add_argument(
        "--atol", type=float, default=1e-4, help="Absolute difference tolerance when validating the model."
    )
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()

    # ``output`` names the ONNX file itself when it is an existing file, or when it ends in ``.onnx`` and is not
    # a directory; otherwise it is a directory that receives ``model.onnx``. Checking only ``is_file()`` would
    # turn a not-yet-existing ``out/model.onnx`` into ``out/model.onnx/model.onnx``.
    output = args.output
    targets_file = output.is_file() or (output.suffix == ".onnx" and not output.is_dir())
    args.output = output if targets_file else output.joinpath("model.onnx")

    # ``exist_ok`` avoids failing if the directory appears between an existence check and its creation.
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Allocate the model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_features(args.features, args.model)
    model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
    onnx_config = model_onnx_config(model.config)

    # Ensure the requested opset is sufficient
    if args.opset < onnx_config.default_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.default_onnx_opset} is required."
        )

    # The first value returned by ``export`` is not needed below; only the outputs feed the validation step.
    _onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")


if __name__ == "__main__":
    # Script entry point: make the "transformers.onnx" logger emit INFO records (``main`` reports where the
    # model was saved at INFO level), then run the exporter CLI.
    logger = logging.get_logger("transformers.onnx")  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)
    main()