# NOTE(review): removed non-Python residue captured from a file-viewer UI
# ("Kano001's picture / Upload 2707 files / dc2106c verified / raw /
#  history blame / 7.89 kB") — it was not part of the original source.
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
import argparse
import json
import os
import shutil
import warnings
import onnx.backend.test.case.model as model_test
import onnx.backend.test.case.node as node_test
from onnx import ONNX_ML, TensorProto, numpy_helper
# Directory containing this script; generated test data defaults to ./data
# next to it (overridable via the --output CLI flag below).
TOP_DIR = os.path.realpath(os.path.dirname(__file__))
DATA_DIR = os.path.join(TOP_DIR, "data")
def generate_data(args: argparse.Namespace) -> None:
    """Generate serialized test data under ``args.output``.

    For every collected test case a directory ``<output>/<kind>/<name>`` is
    (re)created.  "real" cases only get a ``data.json`` descriptor (the model
    is downloaded at test time); every other case gets the serialized model
    plus one ``test_data_set_<i>`` directory per data set, containing
    ``input_<j>.pb`` / ``output_<j>.pb`` protobuf files.

    Args:
        args: parsed CLI namespace with ``output``, ``clean``, ``op_type``
            and ``diff`` attributes (see ``parse_args``).
    """

    def prepare_dir(path: str) -> None:
        # Always start from an empty directory so stale files never linger.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)

    def write_value(fh, value, value_info) -> None:
        # Serialize one input/output to the open binary file ``fh``,
        # dispatching on the declared type of the corresponding graph value.
        vtype = value_info.type
        if vtype.HasField("map_type"):
            proto = numpy_helper.from_dict(value, value_info.name)
        elif vtype.HasField("sequence_type"):
            proto = numpy_helper.from_list(value, value_info.name)
        elif vtype.HasField("optional_type"):
            proto = numpy_helper.from_optional(value, value_info.name)
        else:
            assert vtype.HasField("tensor_type")
            if isinstance(value, TensorProto):
                proto = value
            else:
                proto = numpy_helper.from_array(value, value_info.name)
        fh.write(proto.SerializeToString())

    # Count pre-existing node-test directories so we can warn later if the
    # script does not regenerate all of them (e.g. stale, hand-added dirs).
    # Bug fix: the original tested os.path.isfile(name) relative to the CWD
    # (effectively always False); each test case is a *directory* under
    # node_root, so join with node_root and count directories.  Also guard
    # against node_root not existing (listdir would raise).
    node_root = os.path.join(args.output, "node")
    if os.path.exists(node_root):
        original_dir_number = sum(
            1
            for name in os.listdir(node_root)
            if os.path.isdir(os.path.join(node_root, name))
        )
    else:
        original_dir_number = 0

    # Clean the output directory before generating data for node testcases.
    # It is used to check new generated data is correct in CIs.
    if args.clean and os.path.exists(node_root):
        for sub_dir in os.listdir(node_root):
            # Keep ai.onnx.ml cases untouched when ONNX_ML support is off.
            if ONNX_ML or not sub_dir.startswith("test_ai_onnx_ml_"):
                shutil.rmtree(os.path.join(node_root, sub_dir))

    cases = model_test.collect_testcases()
    # If op_type is specified, only include those testcases including the
    # given operator; otherwise include all of the testcases.
    if args.diff:
        cases += node_test.collect_diff_testcases()
    else:
        cases += node_test.collect_testcases(args.op_type)

    node_number = 0
    for case in cases:
        output_dir = os.path.join(args.output, case.kind, case.name)
        prepare_dir(output_dir)
        if case.kind == "node":
            node_number += 1
        if case.kind == "real":
            # Real models are fetched at test time; store metadata only.
            with open(os.path.join(output_dir, "data.json"), "w") as fi:
                json.dump(
                    {
                        "url": case.url,
                        "model_name": case.model_name,
                        "rtol": case.rtol,
                        "atol": case.atol,
                    },
                    fi,
                    sort_keys=True,
                )
        else:
            assert case.model
            with open(os.path.join(output_dir, "model.onnx"), "wb") as f:
                f.write(case.model.SerializeToString())
            assert case.data_sets
            for i, (inputs, outputs) in enumerate(case.data_sets):
                data_set_dir = os.path.join(output_dir, f"test_data_set_{i}")
                prepare_dir(data_set_dir)
                graph = case.model.graph
                for j, value in enumerate(inputs):
                    with open(
                        os.path.join(data_set_dir, f"input_{j}.pb"), "wb"
                    ) as f:
                        write_value(f, value, graph.input[j])
                for j, value in enumerate(outputs):
                    with open(
                        os.path.join(data_set_dir, f"output_{j}.pb"), "wb"
                    ) as f:
                        write_value(f, value, graph.output[j])

    if not args.clean and node_number != original_dir_number:
        # Typo fix: the message previously read "cannot not be generated".
        warnings.warn(
            "There are some models under 'onnx/backend/test/data/node' which"
            " cannot be generated by the script from"
            " 'onnx/backend/test/case/node'. Please add '--clean' option for"
            " 'python onnx/backend/test/cmd_tools.py generate-data'"
            " to cleanup the existing directories and regenerate them.",
            Warning,
            stacklevel=2,
        )
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface.

    Exposes a single sub-command, ``generate-data``; the handler for the
    chosen sub-command is stored on the returned namespace as ``func``.
    """
    parser = argparse.ArgumentParser("backend-test-tools")
    commands = parser.add_subparsers()

    gen_cmd = commands.add_parser(
        "generate-data", help="convert testcases to test data."
    )
    gen_cmd.add_argument(
        "-c",
        "--clean",
        action="store_true",
        default=False,
        help="Clean the output directory before generating data for node testcases.",
    )
    gen_cmd.add_argument(
        "-o",
        "--output",
        default=DATA_DIR,
        help="output directory (default: %(default)s)",
    )
    gen_cmd.add_argument(
        "-t",
        "--op_type",
        default=None,
        help="op_type for test case generation. (generates test data for the specified op_type only.)",
    )
    gen_cmd.add_argument(
        "-d",
        "--diff",
        action="store_true",
        default=False,
        help="only generates test data for those changed files (compared to the main branch).",
    )
    gen_cmd.set_defaults(func=generate_data)

    return parser.parse_args()
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the selected sub-command."""
    cli_args = parse_args()
    cli_args.func(cli_args)
# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    main()