Spaces:
Running
Running
namespace at::native {
// Builds a sequence of `outputSize` integer offsets for one pooled dimension.
//
// With alpha = (inputSize - poolSize) / (outputSize - 1), every entry i except
// the last is
//     int((i + sample) * alpha) - int(sample * alpha)
// and the last entry is always exactly `inputSize - poolSize`.
// When `outputSize` is 1 the result is just {inputSize - poolSize}; when it is
// 0 the result is empty.
//
// NOTE(review): the entries are presumably the start indices of the pooling
// windows and `sample` a random value in [0, 1) -- confirm against callers.
//
// sample     - offset added to every index before scaling by alpha
// inputSize  - extent of the input along this dimension
// outputSize - number of entries to produce
// poolSize   - extent of the pooling window along this dimension
template<typename scalar_t>
static inline std::vector<int> generate_intervals(
    scalar_t sample,
    int64_t inputSize,
    int64_t outputSize,
    int64_t poolSize) {
  const int64_t last_start = inputSize - poolSize;
  std::vector<int> starts(outputSize);

  if (outputSize > 1) {
    // Scale factor between consecutive entries (only defined for >= 2 outputs).
    const scalar_t step = static_cast<scalar_t>(last_start) /
        static_cast<scalar_t>(outputSize - 1);
    // Subtracted from every entry so that entry 0 is always 0.
    const int base = static_cast<int>(sample * step);
    for (int64_t idx = 0; idx < outputSize - 1; ++idx) {
      starts[idx] = static_cast<int>((idx + sample) * step) - base;
    }
  }

  if (outputSize > 0) {
    // The final entry is pinned rather than computed from `step`.
    starts.back() = static_cast<int>(last_start);
  }
  return starts;
}
template <int64_t ndim> | |
static inline void fractional_max_pool_check_shape( | |
const Tensor& input, | |
const Tensor& randomSamples) { | |
TORCH_CHECK( | |
input.scalar_type() == randomSamples.scalar_type(), | |
"Expect _random_samples to have the same dtype as input"); | |
int64_t ndimension = randomSamples.ndimension(); | |
TORCH_CHECK( | |
ndimension == 3, | |
"Expect _random_samples to have 3 dimensions, got ", ndimension); | |
int64_t N = randomSamples.size(0); | |
int64_t C = randomSamples.size(1); | |
int64_t D = randomSamples.size(2); | |
int64_t input_batch, input_channel; | |
if (ndim == 2) { | |
// fractional_max_pool2d | |
if (input.ndimension() == 3) { | |
input_batch = 1; | |
input_channel = input.size(0); | |
} else { | |
input_batch = input.size(0); | |
input_channel = input.size(1); | |
} | |
} else { | |
// factional_max_pool3d | |
if (input.ndimension() == 4) { | |
input_batch = 1; | |
input_channel = input.size(0); | |
} else { | |
input_batch = input.size(0); | |
input_channel = input.size(1); | |
} | |
} | |
TORCH_CHECK( | |
N >= input_batch, | |
"Expect _random_samples.size(0) no less then input batch size."); | |
TORCH_CHECK( | |
C == input_channel, | |
"Expect _random_samples.size(1) equals to input channel size."); | |
TORCH_CHECK( | |
D == ndim, | |
"Expect _random_samples.size(2) equals to ", ndim, "; got ", D, "."); | |
} | |
} // namespace at::native