Spaces:
Sleeping
Sleeping
namespace at { | |
/* | |
[collapse dims] Updates sizes, and strides to reflect a "collapse" of | |
the info, possibly excluding the optional excludeDim. A "collapsed" version | |
of the info is the fewest dims that order the tensor's elements in the same | |
way as the original info. If excludeDim is specified, the collapse is the | |
fewest dims that order the tensor's elements as the original and preserve the | |
excluded dimension, unless the tensor collapses to a point. | |
This function returns a pair of values. | |
1) The (new) index of the preserved dimension if excludeDim is | |
specified. 0 if the tensor is collapsed to a point. -1 | |
otherwise. | |
2) The new number of dimensions. | |
*/ | |
template <typename T> | |
inline std::pair<int64_t, int64_t> collapse_dims( | |
T* sizes, | |
T* strides, | |
int64_t dims, | |
const int excludeDim = -1) { | |
TORCH_CHECK( | |
excludeDim >= -1 && excludeDim < dims, | |
"expected excluded dim between -1 and dims - 1"); | |
int64_t stopDim = (excludeDim == -1) ? dims : excludeDim; | |
int64_t newIndex = -1; | |
int64_t oldIndex = 0; | |
int64_t remappedExcludedDim = -1; | |
while (oldIndex < dims) { | |
// Finds a dimension to collapse into | |
for (; oldIndex < stopDim; ++oldIndex) { | |
if (sizes[oldIndex] == 1) { | |
continue; | |
} | |
++newIndex; | |
sizes[newIndex] = sizes[oldIndex]; | |
strides[newIndex] = strides[oldIndex]; | |
++oldIndex; | |
break; | |
} | |
// Collapses dims | |
for (; oldIndex < stopDim; ++oldIndex) { | |
if (sizes[oldIndex] == 1) { | |
continue; | |
} | |
if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { | |
sizes[newIndex] *= sizes[oldIndex]; | |
strides[newIndex] = strides[oldIndex]; | |
} else { | |
++newIndex; | |
sizes[newIndex] = sizes[oldIndex]; | |
strides[newIndex] = strides[oldIndex]; | |
} | |
} | |
// Handles excludeDim being set (oldIndex == excludeDim) | |
if (oldIndex != dims) { | |
// Preserves excluded dimension | |
++newIndex; | |
sizes[newIndex] = sizes[oldIndex]; | |
strides[newIndex] = strides[oldIndex]; | |
remappedExcludedDim = newIndex; | |
// Restarts iteration after excludeDim | |
++oldIndex; | |
stopDim = dims; | |
} | |
} | |
// Handles special case of all dims size 1 | |
if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { | |
dims = 1; | |
sizes[0] = 1; | |
strides[0] = 1; | |
return std::pair<int64_t, int64_t>(0, 1); | |
} | |
dims = newIndex + 1; | |
return std::pair<int64_t, int64_t>(remappedExcludedDim, dims); | |
} | |
} // namespace at | |