File size: 3,776 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np

import onnx
from onnx import TensorProto, helper, numpy_helper, shape_inference


class TestTrainingTool(unittest.TestCase):
    """Tests for attaching a ``TrainingInfoProto`` to a model and validating it."""

    @staticmethod
    def _float_tensor(name: str, shape: list, rng: np.random.Generator) -> tuple:
        """Build a random FLOAT tensor as an initializer plus matching value info.

        Args:
            name: Name shared by the initializer and the value info.
            shape: Tensor shape.
            rng: Generator supplying the float32 tensor values.

        Returns:
            ``(initializer, value_info)`` for the tensor.
        """
        values = rng.random(shape, dtype=np.float32)
        initializer = numpy_helper.from_array(values, name=name)
        value_info = helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        return initializer, value_info

    def test_training_info_proto(self) -> None:
        """Attach a training algorithm to a model and check the merged graph.

        The inference graph computes ``C = MatMul(A, B)``. The training
        algorithm computes ``Y = MatMul(X, C)`` and binds its output ``Y``
        back to the inference initializer ``B``.
        """
        # Seeded so that any failure is reproducible; the tensor values
        # themselves do not affect the structural checks below.
        rng = np.random.default_rng(seed=0)
        shape = [2, 2]

        # Inference graph: C = MatMul(A, B), with A and B stored as initializers.
        A_name, B_name, C_name = "A", "B", "C"
        A_initializer, A_value_info = self._float_tensor(A_name, shape, rng)
        B_initializer, B_value_info = self._float_tensor(B_name, shape, rng)
        C_value_info = helper.make_tensor_value_info(C_name, TensorProto.FLOAT, shape)

        inference_node = helper.make_node(
            "MatMul", inputs=[A_name, B_name], outputs=[C_name]
        )
        inference_graph = helper.make_graph(
            [inference_node],
            "simple_inference",
            [A_value_info, B_value_info],
            [C_value_info],
            [A_initializer, B_initializer],
        )

        # Training graph: Y = MatMul(X, C), where X is its own initializer.
        X_name, Y_name = "X", "Y"
        X_initializer, X_value_info = self._float_tensor(X_name, shape, rng)
        Y_value_info = helper.make_tensor_value_info(Y_name, TensorProto.FLOAT, shape)

        training_node = helper.make_node(
            "MatMul",
            inputs=[X_name, C_name],  # tensor "C" is produced by the inference graph.
            outputs=[Y_name],
        )
        training_graph = helper.make_graph(
            [training_node],
            "simple_training",
            [X_value_info],
            [Y_value_info],
            [X_initializer],
        )

        # Capture the assignment B <--- Y as an update binding. No
        # initialization algorithm or initialization bindings are supplied.
        training_info = helper.make_training_info(
            training_graph, [(B_name, Y_name)], None, None
        )

        # Create a model and check the inference-only part first.
        model = helper.make_model(inference_graph)
        onnx.checker.check_model(model)

        # Insert the training information.
        model.training_info.add().CopyFrom(training_info)

        # The model must carry exactly the training information we built.
        self.assertEqual(len(model.training_info), 1)
        stored = model.training_info[0]
        self.assertEqual(stored.algorithm.name, "simple_training")
        self.assertEqual(
            [(binding.key, binding.value) for binding in stored.update_binding],
            [(B_name, Y_name)],
        )

        # Build the full training graph so the checker can validate it. Per
        # the spec, it is formed by concatenating the corresponding fields of
        # the inference graph and the training algorithm.
        full_training_graph = helper.make_graph(
            list(model.graph.node) + list(stored.algorithm.node),
            "full_training_graph",
            list(model.graph.input) + list(stored.algorithm.input),
            list(model.graph.output) + list(stored.algorithm.output),
            list(model.graph.initializer) + list(stored.algorithm.initializer),
        )

        # Wrap the full training graph as a ModelProto so the checker can run.
        full_training_model = helper.make_model(full_training_graph)
        full_training_model_with_shapes = shape_inference.infer_shapes(
            full_training_model
        )
        onnx.checker.check_model(full_training_model_with_shapes)


if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main(module="__main__")