Spaces:
Running
Running
File size: 1,273 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
/*
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
#include <functional>

#include "onnx/defs/schema.h"
namespace ONNX_NAMESPACE {
// Declare the preview training operator schema classes (all at version 1).
// Only declarations appear in this header; OpSet_OnnxPreview_ver1 below
// refers to them by name when enumerating the operator set.
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad);
class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam);

// Iterates over version 1 of the *preview* training operator set
// (Gradient, Momentum, Adagrad, Adam).
// NOTE(review): presumably these live in the "ai.onnx.preview.training"
// domain rather than "ai.onnx.training" -- confirm against the schemas'
// SetDomain calls.
class OpSet_OnnxPreview_ver1 {
 public:
  /// @brief Invokes @p fn once for each preview training schema, in the
  ///        order Gradient, Momentum, Adagrad, Adam.
  /// @param fn Callback receiving each schema returned by GetOpSchema<>()
  ///           as an rvalue, so the callee may move it into a registry.
  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient)>());
    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum)>());
    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adagrad)>());
    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Adam)>());
  }
};
// Register preview operators.
inline void RegisterOnnxPreviewOperatorSetSchema() {
// Preview operators should have only one version.
// If changes are needed for a specific preview operator,
// its spec should be modified without increasing its version.
RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
}
} // namespace ONNX_NAMESPACE
|