# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import unittest

import numpy as np

import onnx
from onnx import TensorProto, helper, numpy_helper, shape_inference


def _float_info(name: str, shape: list) -> onnx.ValueInfoProto:
    """Return a FLOAT tensor ValueInfoProto called *name* with dimensions *shape*."""
    return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)


def _random_float_initializer(name: str, shape: list) -> TensorProto:
    """Return a TensorProto called *name* holding float32 data from ``np.random.rand``."""
    data = np.random.rand(*shape).astype(np.float32)
    return numpy_helper.from_array(data, name=name)


class TestTrainingTool(unittest.TestCase):
    def test_training_info_proto(self) -> None:
        """Build a model with TrainingInfoProto and check its merged training graph."""
        square = [2, 2]

        # Inference graph: C = MatMul(A, B), with A and B stored as initializers.
        inference_graph = helper.make_graph(
            [helper.make_node("MatMul", inputs=["A", "B"], outputs=["C"])],
            "simple_inference",
            [_float_info("A", square), _float_info("B", square)],
            [_float_info("C", square)],
            [
                _random_float_initializer("A", square),
                _random_float_initializer("B", square),
            ],
        )

        # Training algorithm graph: Y = MatMul(X, C). "C" is not produced here;
        # it refers to the inference graph's output.
        training_graph = helper.make_graph(
            [helper.make_node("MatMul", inputs=["X", "C"], outputs=["Y"])],
            "simple_training",
            [_float_info("X", square)],
            [_float_info("Y", square)],
            [_random_float_initializer("X", square)],
        )

        # Binding pair (B, Y): initializer B is bound to training output Y.
        training_info = helper.make_training_info(
            training_graph, [("B", "Y")], None, None
        )

        # Start from an inference-only model and validate it before any
        # training information is attached.
        model = helper.make_model(inference_graph)
        onnx.checker.check_model(model)

        # Attach the training information to the model.
        model.training_info.add().CopyFrom(training_info)

        # Per the spec, the full training graph is the field-by-field
        # concatenation of the inference graph and the training algorithm graph.
        # Materialize it so the checker can validate the combined graph.
        inference = model.graph
        algorithm = model.training_info[0].algorithm
        full_training_graph = helper.make_graph(
            [*inference.node, *algorithm.node],
            "full_training_graph",
            [*inference.input, *algorithm.input],
            [*inference.output, *algorithm.output],
            [*inference.initializer, *algorithm.initializer],
        )

        # The checker operates on a ModelProto, so wrap the merged graph in one
        # and run shape inference on it before checking.
        full_training_model = shape_inference.infer_shapes(
            helper.make_model(full_training_graph)
        )
        onnx.checker.check_model(full_training_model)


if __name__ == "__main__":
    unittest.main()