Spaces:
Sleeping
Sleeping
File size: 3,776 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper, shape_inference
class TestTrainingTool(unittest.TestCase):
    """Checks that a model carrying ``TrainingInfoProto`` data is well formed."""

    @staticmethod
    def _make_initialized_input(name: str, shape: list, rng) -> tuple:
        """Create a random float32 tensor exposed as both an initializer and an input.

        Args:
            name: Tensor name shared by the initializer and the value info.
            shape: Dimensions of the tensor.
            rng: ``numpy.random.Generator`` supplying the tensor values.

        Returns:
            ``(initializer, value_info)`` for the tensor.
        """
        data = rng.random(tuple(shape)).astype(np.float32)
        initializer = numpy_helper.from_array(data, name=name)
        value_info = helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)
        return initializer, value_info

    def test_training_info_proto(self) -> None:
        # A fixed seed keeps the tensor data reproducible across runs; the
        # values themselves do not influence what this test checks.
        rng = np.random.default_rng(0)
        shape = [2, 2]  # every tensor in this test is 2x2

        # Inference graph: C = MatMul(A, B), with A and B stored as initializers.
        A_name, B_name, C_name = "A", "B", "C"
        A_initializer, A_value_info = self._make_initialized_input(A_name, shape, rng)
        B_initializer, B_value_info = self._make_initialized_input(B_name, shape, rng)
        C_value_info = helper.make_tensor_value_info(C_name, TensorProto.FLOAT, shape)
        inference_node = helper.make_node(
            "MatMul", inputs=[A_name, B_name], outputs=[C_name]
        )
        inference_graph = helper.make_graph(
            [inference_node],
            "simple_inference",
            [A_value_info, B_value_info],
            [C_value_info],
            [A_initializer, B_initializer],
        )

        # Training graph: Y = MatMul(X, C). X is an initializer local to the
        # training graph, while C is produced by the inference graph.
        X_name, Y_name = "X", "Y"
        X_initializer, X_value_info = self._make_initialized_input(X_name, shape, rng)
        Y_value_info = helper.make_tensor_value_info(Y_name, TensorProto.FLOAT, shape)
        training_node = helper.make_node(
            "MatMul",
            inputs=[X_name, C_name],  # tensor "C" is from the inference graph.
            outputs=[Y_name],
        )
        training_graph = helper.make_graph(
            [training_node],
            "simple_training",
            [X_value_info],
            [Y_value_info],
            [X_initializer],
        )

        # Capture the assignment B <--- Y. No initialization graph or
        # initialization bindings are supplied.
        training_info = helper.make_training_info(
            training_graph, [(B_name, Y_name)], None, None
        )

        # Create the model from the inference graph and check that the
        # inference-only part is valid before any training data is attached.
        model = helper.make_model(inference_graph)
        onnx.checker.check_model(model)

        # Attach the training information and confirm it landed on the model
        # with the expected update binding.
        model.training_info.add().CopyFrom(training_info)
        self.assertEqual(len(model.training_info), 1)
        stored = model.training_info[0]
        self.assertEqual(len(stored.update_binding), 1)
        self.assertEqual(stored.update_binding[0].key, B_name)
        self.assertEqual(stored.update_binding[0].value, Y_name)

        # Build the actual training graph from the training information so the
        # ONNX checker can validate it. As defined in the spec, the full
        # training graph is formed by concatenating the corresponding fields of
        # the inference graph and the training algorithm graph.
        algorithm = stored.algorithm
        full_training_graph = helper.make_graph(
            list(model.graph.node) + list(algorithm.node),
            "full_training_graph",
            list(model.graph.input) + list(algorithm.input),
            list(model.graph.output) + list(algorithm.output),
            list(model.graph.initializer) + list(algorithm.initializer),
        )

        # Wrap the full training graph in a ModelProto so the checker can run.
        full_training_model = helper.make_model(full_training_graph)
        full_training_model_with_shapes = shape_inference.infer_shapes(
            full_training_model
        )
        onnx.checker.check_model(full_training_model_with_shapes)
if __name__ == "__main__":
    # Executing this file directly runs the tests in this module via unittest.main().
    unittest.main()
|