# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np

import onnx
from onnx import TensorProto, helper, numpy_helper, shape_inference


class TestTrainingTool(unittest.TestCase):
    def test_training_info_proto(self) -> None:
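        """Attach a TrainingInfoProto to a model and validate the result.

        Builds a small inference graph (C = MatMul(A, B)), attaches a training
        graph that consumes C and updates B, and checks both the inference-only
        model and the assembled full training graph with the ONNX checker.
        """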
        # Inference graph.
        A_shape = [2, 2]
        A_name = "A"
        A = np.random.rand(*A_shape).astype(np.float32)
        A_initializer = numpy_helper.from_array(A, name=A_name)
        A_value_info = helper.make_tensor_value_info(A_name, TensorProto.FLOAT, A_shape)

        B_shape = [2, 2]
        B_name = "B"
        B = np.random.rand(*B_shape).astype(np.float32)
        B_initializer = numpy_helper.from_array(B, name=B_name)
        B_value_info = helper.make_tensor_value_info(B_name, TensorProto.FLOAT, B_shape)

        C_shape = [2, 2]
        C_name = "C"
        C_value_info = helper.make_tensor_value_info(C_name, TensorProto.FLOAT, C_shape)

        inference_node = helper.make_node(
            "MatMul", inputs=[A_name, B_name], outputs=[C_name]
        )
        inference_graph = helper.make_graph(
            [inference_node],
            "simple_inference",
            [A_value_info, B_value_info],
            [C_value_info],
            [A_initializer, B_initializer],
        )
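
        # The inference graph computes C = MatMul(A, B), with A and B provided
        # both as graph inputs and as initializers.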
        # Training graph.
        X_shape = [2, 2]
        X_name = "X"
        X = np.random.rand(*X_shape).astype(np.float32)
        X_initializer = numpy_helper.from_array(X, name=X_name)
        X_value_info = helper.make_tensor_value_info(X_name, TensorProto.FLOAT, X_shape)

        Y_shape = [2, 2]
        Y_name = "Y"
        Y_value_info = helper.make_tensor_value_info(Y_name, TensorProto.FLOAT, Y_shape)

        node = helper.make_node(
            "MatMul",
            inputs=[X_name, C_name],  # Tensor "C" comes from the inference graph.
            outputs=[Y_name],
        )
        training_graph = helper.make_graph(
            [node], "simple_training", [X_value_info], [Y_value_info], [X_initializer]
        )

        # Capture assignment of B <--- Y.
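        # The (B_name, Y_name) pair becomes an update binding in the resulting
        # TrainingInfoProto: after the training (algorithm) graph runs, the
        # value of output "Y" is assigned back to the initializer "B". The two
        # trailing None arguments leave the optional initialization graph and
        # its bindings unset.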
        training_info = helper.make_training_info(
            training_graph, [(B_name, Y_name)], None, None
        )

        # Create a model with both inference and training information.
        model = helper.make_model(inference_graph)
        # Check if the inference-only part is correct.
        onnx.checker.check_model(model)
        # Insert training information.
        new_training_info = model.training_info.add()
        new_training_info.CopyFrom(training_info)
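        # ModelProto.training_info is a repeated field, so a model can carry
        # multiple TrainingInfoProto entries; this test appends a single one.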

        # Generate the actual training graph from the training information so
        # that we can run the ONNX checker and verify that the full training
        # graph is valid. As defined in the spec, the full training graph is
        # formed by concatenating the corresponding fields.
        full_training_graph = helper.make_graph(
            list(model.graph.node) + list(model.training_info[0].algorithm.node),
            "full_training_graph",
            list(model.graph.input) + list(model.training_info[0].algorithm.input),
            list(model.graph.output) + list(model.training_info[0].algorithm.output),
            list(model.graph.initializer)
            + list(model.training_info[0].algorithm.initializer),
        )
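        # The concatenated graph contains both MatMul nodes, so the training
        # node's input "C" resolves to the output of the inference node.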

        # Wrap the full training graph in a ModelProto so that we can run the checker.
        full_training_model = helper.make_model(full_training_graph)
        full_training_model_with_shapes = shape_inference.infer_shapes(
            full_training_model
        )
        onnx.checker.check_model(full_training_model_with_shapes)


if __name__ == "__main__":
    unittest.main()