// Copyright (c) ONNX Project Contributors

/*
 * SPDX-License-Identifier: Apache-2.0
 */
namespace ONNX_NAMESPACE { | |
namespace version_conversion { | |
/// Convenience entry point: converts `mp_in` to opset `target_version`.
///
/// The starting opset version is read from the first opset_import entry whose
/// domain is "" or "ai.onnx". If the model declares no such entry, the
/// starting version stays 0. `target_version` is wrapped in a version-only
/// OpSetID, and the actual conversion is delegated to
/// DefaultVersionConverter::convert_version.
///
/// @param mp_in          Model to convert (not modified).
/// @param target_version Opset version to convert to.
/// @return The converted model.
ModelProto ConvertVersion(const ModelProto& mp_in, int target_version) {
  OpSetID initial_struct(0);
  for (const auto& opset : mp_in.opset_import()) {
    const bool is_default_domain = opset.domain() == "" || opset.domain() == "ai.onnx";
    if (is_default_domain) {
      // Only the first default-domain declaration is honoured.
      initial_struct.setVersion(opset.version());
      break;
    }
  }

  const OpSetID target_struct(target_version);
  DefaultVersionConverter converter;
  return converter.convert_version(mp_in, initial_struct, target_struct);
}
void DefaultVersionConverter::convert_graph( | |
std::shared_ptr<Graph> g, | |
const OpSetID& initial_version, | |
const OpSetID& target_version) const { | |
assertNonNull(g); | |
// TODO: Move to Inter-Domain Converter | |
// Get initial model versions | |
// std::vector<OpSetID> initial_versions = g->opset_versions_mutable(); | |
// No conversion necessary if Model has single, equivalent opset version | |
// if (initial_versions.size() == 1 && initial_versions[0].version == | |
// target_version.version && initial_versions[0].domain == | |
// target_version.domain) { | |
// return mp_in; | |
// } | |
// Check if versions are valid | |
assertInVersionRange(initial_version.version()); | |
assertInVersionRange(target_version.version()); | |
// Iterate over all versions to target_version for specified | |
int64_t curr_version = initial_version.version(); | |
int64_t step; | |
if (target_version.version() > initial_version.version()) { | |
step = 1; | |
} else { | |
step = -1; | |
} | |
// Identify index of this domain in g.opset_versions | |
unsigned int domain_index = 0; | |
for (unsigned int i = 0; i < g->opset_versions_mutable().size(); i++) { | |
if (g->opset_versions_mutable()[i].domain() == "") { | |
domain_index = i; | |
} | |
} | |
while (curr_version != target_version.version()) { | |
debug( | |
"curr_version: " + ONNX_NAMESPACE::to_string(curr_version) + | |
", next_version: " + ONNX_NAMESPACE::to_string(curr_version + step)); | |
Node* cur_op; | |
graph_node_list_iterator it = g->begin(); | |
// Iterate through and call adapter returned by adapter_lookup for ops from | |
// current_version opset. We have to manipulate the iterator explicitly because cur_op | |
// might change when applying the adapter (e.g. for deprecated ops) | |
while (it != g->end()) { | |
cur_op = *it; | |
debug(std::string("Finding schema for ") + std::string(cur_op->kind().toString())); | |
const std::string op_name = cur_op->kind().toString(); | |
if (op_name == "ConstantFill") { | |
if (DEBUG) { | |
std::cerr | |
<< "Warning: skipping schema search for experimental op 'ConstantFill' and keeping the op as is. " | |
"Please be advised the converted model may not be working properly if target runtime does not support this " | |
"experimental op." | |
<< std::endl; | |
} | |
} else if (cur_op->domain() != "" && cur_op->domain() != "ai.onnx") { | |
if (DEBUG) { | |
std::cerr << "Warning: opset domain '" << cur_op->domain() << "' is not supported." << std::endl; | |
} | |
} else if (op_name != "Undefined" && op_name != "Captured") { | |
auto& op_domain_map = all_schemas.at(op_name); | |
OpSetID curr_id(curr_version); | |
OpSetID next_id(curr_version + step); | |
if (searchOpDomainMap(op_domain_map, curr_version, step)) { | |
// Op is specifically defined for this domain and version | |
auto& op_adapter = adapter_lookup(cur_op, curr_id, next_id); | |
// If adapter_lookup returns null, no adapter is present. | |
// Error thrown by adapter_lookup | |
if (DEBUG) { | |
std::cerr << "Applying adapter" << std::endl; | |
} | |
// adapt should handle replacing node in graph | |
cur_op = op_adapter.adapt(g, cur_op); | |
it = graph_node_list_iterator(cur_op, kNextDirection); | |
} | |
// Recursively convert any subgraph attributes | |
for (const auto& attr : cur_op->attributeNames()) { | |
if (cur_op->kindOf(attr) == AttributeKind::g) { | |
convert_graph(cur_op->g(attr), curr_id, next_id); | |
} | |
} | |
} | |
it++; | |
} | |
// Update model version | |
curr_version += step; | |
g->opset_versions_mutable()[domain_index].incrementVersion(step); | |
} | |
} | |
/// Converts `mp_in` from `initial_version` to `target_version` and returns
/// the converted model. `mp_in` itself is taken by const reference and is
/// not modified.
///
/// Both versions must pass assertDefaultDomain. Every opset_import entry
/// whose domain string equals initial_version's domain must declare exactly
/// initial_version's version; otherwise ONNX_ASSERTM fires.
///
/// @param mp_in           Model to convert.
/// @param initial_version Opset version the model is expected to be at.
/// @param target_version  Opset version to convert to.
/// @return A new ModelProto built by PrepareOutput(mp_in) and filled with
///         the converted graph.
ModelProto DefaultVersionConverter::convert_version(
    const ModelProto& mp_in,
    const OpSetID& initial_version,
    const OpSetID& target_version) const {
  const std::string& initial_domain = initial_version.domain();
  const std::string& target_domain = target_version.domain();
  assertDefaultDomain(initial_domain, target_domain);

  // The caller-supplied starting version must agree with what the model
  // itself declares for that domain.
  for (const auto& opset : mp_in.opset_import()) {
    if (opset.domain() != initial_domain) {
      continue;
    }
    ONNX_ASSERTM(
        initial_version.version() == opset.version(), "initial_version does not reflect current state of model");
  }

  // Import into the in-memory Graph, convert it, then export into a fresh proto.
  std::shared_ptr<Graph> graph(ImportModelProto(mp_in));
  convert_graph(graph, initial_version, target_version);
  debug("Finished conversion; returning model");
  ModelProto mp_out = PrepareOutput(mp_in);
  ExportModelProto(&mp_out, graph);
  return mp_out;
}
} // namespace version_conversion | |
} // namespace ONNX_NAMESPACE | |