Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
import unittest | |
import numpy as np | |
import onnx | |
from onnx import TensorProto, helper, numpy_helper, shape_inference | |
class TestTrainingTool(unittest.TestCase):
    """Tests for attaching a ``TrainingInfoProto`` to an ONNX model."""

    def test_training_info_proto(self) -> None:
        """Attach a training algorithm to a model and validate the merged graph.

        The inference graph computes ``C = MatMul(A, B)``. The training
        algorithm computes ``Y = MatMul(X, C)`` and binds ``Y`` back onto
        ``B`` (``B <- Y``). The test checks that:

        * ``helper.make_training_info`` records exactly the requested update
          binding and no initialization binding;
        * the inference-only model passes the ONNX checker;
        * the training info copied into ``model.training_info`` equals the
          one that was built;
        * the "full training graph" passes shape inference and the checker.
          It is built by concatenating the nodes, inputs, outputs and
          initializers of the inference graph and the training algorithm.
        """
        # Nothing below depends on the tensor values. A fixed seed keeps the
        # generated model identical from run to run.
        rng = np.random.default_rng(0)

        # Inference graph: C = MatMul(A, B). A and B are graph inputs that
        # also have initializers.
        A_shape = [2, 2]
        A_name = "A"
        A = rng.random(A_shape).astype(np.float32)
        A_initializer = numpy_helper.from_array(A, name=A_name)
        A_value_info = helper.make_tensor_value_info(A_name, TensorProto.FLOAT, A_shape)

        B_shape = [2, 2]
        B_name = "B"
        B = rng.random(B_shape).astype(np.float32)
        B_initializer = numpy_helper.from_array(B, name=B_name)
        B_value_info = helper.make_tensor_value_info(B_name, TensorProto.FLOAT, B_shape)

        C_shape = [2, 2]
        C_name = "C"
        C_value_info = helper.make_tensor_value_info(C_name, TensorProto.FLOAT, C_shape)

        inference_node = helper.make_node(
            "MatMul", inputs=[A_name, B_name], outputs=[C_name]
        )
        inference_graph = helper.make_graph(
            [inference_node],
            "simple_inference",
            [A_value_info, B_value_info],
            [C_value_info],
            [A_initializer, B_initializer],
        )

        # Training algorithm: Y = MatMul(X, C). "C" is not produced in this
        # graph; it refers to the inference graph's output.
        X_shape = [2, 2]
        X_name = "X"
        X = rng.random(X_shape).astype(np.float32)
        X_initializer = numpy_helper.from_array(X, name=X_name)
        X_value_info = helper.make_tensor_value_info(X_name, TensorProto.FLOAT, X_shape)

        Y_shape = [2, 2]
        Y_name = "Y"
        Y_value_info = helper.make_tensor_value_info(Y_name, TensorProto.FLOAT, Y_shape)

        node = helper.make_node(
            "MatMul",
            inputs=[X_name, C_name],  # tensor "C" is from inference graph.
            outputs=[Y_name],
        )
        training_graph = helper.make_graph(
            [node], "simple_training", [X_value_info], [Y_value_info], [X_initializer]
        )

        # Capture assignment of B <--- Y. No initialization graph or bindings.
        training_info = helper.make_training_info(
            training_graph, [(B_name, Y_name)], None, None
        )
        # The helper must record exactly the requested update binding.
        self.assertEqual(
            [(binding.key, binding.value) for binding in training_info.update_binding],
            [(B_name, Y_name)],
        )
        self.assertEqual(len(training_info.initialization_binding), 0)

        # Create a model with both inference and training information.
        model = helper.make_model(inference_graph)
        # Check if the inference-only part is correct.
        onnx.checker.check_model(model)

        # Insert training information and confirm it was stored unchanged.
        new_training_info = model.training_info.add()
        new_training_info.CopyFrom(training_info)
        self.assertEqual(len(model.training_info), 1)
        self.assertEqual(model.training_info[0], training_info)

        # Generate the actual training graph from the training information, so
        # the ONNX checker can verify that the full training graph is valid.
        # As defined in the spec, the full training graph is formed by
        # concatenating the corresponding fields of the inference graph and
        # the training algorithm.
        algorithm = model.training_info[0].algorithm
        full_training_graph = helper.make_graph(
            list(model.graph.node) + list(algorithm.node),
            "full_training_graph",
            list(model.graph.input) + list(algorithm.input),
            list(model.graph.output) + list(algorithm.output),
            list(model.graph.initializer) + list(algorithm.initializer),
        )

        # Wrap the full training graph in a ModelProto so the checker can run.
        full_training_model = helper.make_model(full_training_graph)
        full_training_model_with_shapes = shape_inference.infer_shapes(
            full_training_model
        )
        onnx.checker.check_model(full_training_model_with_shapes)
if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main()