/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "tensor_proto_util.h"

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "onnx/common/platform_helpers.h"
#include "onnx/defs/data_type_utils.h"
#include "onnx/defs/shape_inference.h"
namespace ONNX_NAMESPACE {
/* Defines a ToTensor<cpp_type> specialization for one scalar value: the
 * returned TensorProto carries data_type dtype_enum and a single entry
 * appended to its repeated `proto_field`_data member. */
#define DEFINE_TO_TENSOR_ONE(cpp_type, dtype_enum, proto_field) \
  template <>                                                   \
  TensorProto ToTensor<cpp_type>(const cpp_type& value) {       \
    TensorProto tensor;                                         \
    tensor.set_data_type(dtype_enum);                           \
    tensor.add_##proto_field##_data(value);                     \
    return tensor;                                              \
  }
/* Defines a ToTensor<cpp_type> specialization for a vector of values: the
 * returned TensorProto carries data_type dtype_enum and every element of
 * `values` appended, in order, to its repeated `proto_field`_data member. */
#define DEFINE_TO_TENSOR_LIST(cpp_type, dtype_enum, proto_field)        \
  template <>                                                           \
  TensorProto ToTensor<cpp_type>(const std::vector<cpp_type>& values) { \
    TensorProto tensor;                                                 \
    tensor.clear_##proto_field##_data();                                \
    tensor.set_data_type(dtype_enum);                                   \
    for (const cpp_type& element : values) {                            \
      tensor.add_##proto_field##_data(element);                         \
    }                                                                   \
    return tensor;                                                      \
  }
/*
 * Defines a ParseData<type> specialization that extracts a TensorProto's
 * payload as std::vector<type>.  Elements come from the repeated
 * typed_data_fetch field when present, otherwise from raw_data (serialized
 * little-endian; bytes are swapped on big-endian hosts).  Calls
 * fail_shape_inference for: undefined or mismatched data type, externally
 * stored data, string-typed raw_data, or a dims/payload size mismatch.
 */
#define DEFINE_PARSE_DATA(type, typed_data_fetch, tensorproto_datatype) \
  template <> \
  const std::vector<type> ParseData(const TensorProto* tensor_proto) { \
    if (!tensor_proto->has_data_type() || tensor_proto->data_type() == TensorProto_DataType_UNDEFINED) { \
      fail_shape_inference("The type of tensor: ", tensor_proto->name(), " is undefined so it cannot be parsed."); \
    } else if (tensor_proto->data_type() != tensorproto_datatype) { \
      fail_shape_inference( \
          "ParseData type mismatch for tensor: ", \
          tensor_proto->name(), \
          ". Expected:", \
          Utils::DataTypeUtils::ToDataTypeString(tensorproto_datatype), \
          " Actual:", \
          Utils::DataTypeUtils::ToDataTypeString(tensor_proto->data_type())); \
    } \
    std::vector<type> res; \
    if (tensor_proto->has_data_location() && tensor_proto->data_location() == TensorProto_DataLocation_EXTERNAL) { \
      fail_shape_inference( \
          "Cannot parse data from external tensors. Please ", \
          "load external data into raw data for tensor: ", \
          tensor_proto->name()); \
    } else if (!tensor_proto->has_raw_data()) { \
      const auto& data = tensor_proto->typed_data_fetch(); \
      /* accumulate in 64 bits: dims are int64 and their product can overflow int */ \
      int64_t expected_size = 1; \
      for (int i = 0; i < tensor_proto->dims_size(); ++i) { \
        expected_size *= tensor_proto->dims(i); \
      } \
      /* dims_size()==0 is a scalar; skip the element-count check for it */ \
      if (tensor_proto->dims_size() != 0 && static_cast<int64_t>(data.size()) != expected_size) { \
        fail_shape_inference( \
            "Data size mismatch. Tensor: ", \
            tensor_proto->name(), \
            " expected size ", \
            expected_size, \
            " does not match the actual size ", \
            data.size()); \
      } \
      res.insert(res.end(), data.begin(), data.end()); \
      return res; \
    } \
    if (tensor_proto->data_type() == TensorProto_DataType_STRING) { \
      fail_shape_inference( \
          tensor_proto->name(), \
          " data type is string. string", \
          " content is required to be stored in repeated bytes string_data field.", \
          " raw_data type cannot be string."); \
    } \
    /* The given tensor does have raw_data itself so parse it by given type */ \
    /* make copy as we may have to reverse bytes */ \
    std::string raw_data = tensor_proto->raw_data(); \
    if (raw_data.empty()) { \
      return res; \
    } \
    /* okay to remove const qualifier as we have already made a copy */ \
    char* bytes = const_cast<char*>(raw_data.c_str()); \
    /* onnx is little endian serialized always-tweak byte order if needed */ \
    if (!is_processor_little_endian()) { \
      const size_t element_size = sizeof(type); \
      const size_t num_elements = raw_data.size() / element_size; \
      for (size_t i = 0; i < num_elements; ++i) { \
        char* start_byte = bytes + i * element_size; \
        char* end_byte = start_byte + element_size - 1; \
        /* keep swapping */ \
        for (size_t count = 0; count < element_size / 2; ++count) { \
          char temp = *start_byte; \
          *start_byte = *end_byte; \
          *end_byte = temp; \
          ++start_byte; \
          --end_byte; \
        } \
      } \
    } \
    /* raw_data.c_str()/bytes is a byte array and may not be properly */ \
    /* aligned for the underlying type */ \
    /* We need to copy the raw_data.c_str()/bytes as byte instead of */ \
    /* copying as the underlying type, otherwise we may hit memory */ \
    /* misalignment issues on certain platforms, such as arm32-v7a */ \
    const size_t raw_data_size = raw_data.size(); \
    res.resize(raw_data_size / sizeof(type)); \
    memcpy(reinterpret_cast<char*>(res.data()), bytes, raw_data_size); \
    return res; \
  }
// Scalar ToTensor specializations. Note: bool values are stored in the
// repeated int32_data field, matching the TensorProto wire layout.
DEFINE_TO_TENSOR_ONE(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_ONE(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_ONE(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_ONE(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_ONE(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_ONE(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_ONE(std::string, TensorProto_DataType_STRING, string)
// Vector ToTensor specializations, mirroring the scalar set above.
DEFINE_TO_TENSOR_LIST(float, TensorProto_DataType_FLOAT, float)
DEFINE_TO_TENSOR_LIST(bool, TensorProto_DataType_BOOL, int32)
DEFINE_TO_TENSOR_LIST(int32_t, TensorProto_DataType_INT32, int32)
DEFINE_TO_TENSOR_LIST(int64_t, TensorProto_DataType_INT64, int64)
DEFINE_TO_TENSOR_LIST(uint64_t, TensorProto_DataType_UINT64, uint64)
DEFINE_TO_TENSOR_LIST(double, TensorProto_DataType_DOUBLE, double)
DEFINE_TO_TENSOR_LIST(std::string, TensorProto_DataType_STRING, string)
// ParseData specializations — only the numeric types shape inference reads.
DEFINE_PARSE_DATA(int32_t, int32_data, TensorProto_DataType_INT32)
DEFINE_PARSE_DATA(int64_t, int64_data, TensorProto_DataType_INT64)
DEFINE_PARSE_DATA(float, float_data, TensorProto_DataType_FLOAT)
DEFINE_PARSE_DATA(double, double_data, TensorProto_DataType_DOUBLE)
// Undefine all helper macros (the original only undefined DEFINE_PARSE_DATA;
// undefining the ToTensor helpers too keeps the macro namespace clean).
#undef DEFINE_TO_TENSOR_ONE
#undef DEFINE_TO_TENSOR_LIST
#undef DEFINE_PARSE_DATA
} // namespace ONNX_NAMESPACE