// JF1999's picture
// Upload 413 files
// 430de99 verified
// Copyright (c) Facebook, Inc. and its affiliates.
#pragma once
#include <torch/types.h>
namespace detectron2 {
// CPU backend of the ROIAlign forward pass. Only declared in this header; the
// definition lives elsewhere. ROIAlign_forward in this header routes here
// whenever `input` is not a CUDA tensor, passing every argument through
// unchanged.
//
// NOTE(review): the parameter meanings are inferred from their names and are
// not established by this header — presumably `rois` holds the regions to
// pool, `spatial_scale` maps ROI coordinates onto `input`, and
// `pooled_height` / `pooled_width` give the output size per region. Confirm
// against the implementation.
at::Tensor ROIAlign_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
bool aligned);
// CPU backend of the ROIAlign backward pass. Only declared in this header; the
// definition lives elsewhere. ROIAlign_backward in this header routes here
// whenever `grad` is not a CUDA tensor, passing every argument through
// unchanged.
//
// NOTE(review): `batch_size`, `channels`, `height` and `width` presumably
// describe the shape of the forward input whose gradient is returned — this
// header does not show it; confirm against the implementation.
at::Tensor ROIAlign_backward_cpu(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
bool aligned);
// GPU backends. These declarations exist only when the build defines
// WITH_CUDA or WITH_HIP; in other builds the names are not declared at all,
// which is why the dispatchers in this header guard their calls with the same
// preprocessor condition.
#if defined(WITH_CUDA) || defined(WITH_HIP)
// GPU counterpart of ROIAlign_forward_cpu, with an identical parameter list.
// ROIAlign_forward calls it when `input.is_cuda()` is true.
at::Tensor ROIAlign_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio,
bool aligned);
// GPU counterpart of ROIAlign_backward_cpu, with an identical parameter list.
// ROIAlign_backward calls it when `grad.is_cuda()` is true.
at::Tensor ROIAlign_backward_cuda(
const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio,
bool aligned);
#endif
// Interface for Python
/// Device dispatcher for the ROIAlign forward pass.
///
/// The backend is chosen from `input` alone:
///   - CUDA tensor, build defines WITH_CUDA or WITH_HIP -> ROIAlign_forward_cuda
///   - CUDA tensor, build defines neither               -> AT_ERROR
///   - any other tensor                                 -> ROIAlign_forward_cpu
///
/// Every argument is handed to the selected backend unchanged and the
/// backend's result is returned as-is. `rois` is not inspected here.
inline at::Tensor ROIAlign_forward(
    const at::Tensor& input,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int sampling_ratio,
    bool aligned) {
  const bool on_gpu = input.is_cuda();
#if defined(WITH_CUDA) || defined(WITH_HIP)
  // GPU-enabled build: CUDA tensors go to the GPU implementation.
  if (on_gpu) {
    return ROIAlign_forward_cuda(
        input,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        sampling_ratio,
        aligned);
  }
#else
  // Build without GPU support: a CUDA tensor cannot be served.
  if (on_gpu) {
    AT_ERROR("Not compiled with GPU support");
  }
#endif
  // Anything that is not a CUDA tensor takes the CPU path.
  return ROIAlign_forward_cpu(
      input,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
}
/// Device dispatcher for the ROIAlign backward pass.
///
/// The backend is chosen from `grad` alone:
///   - CUDA tensor, build defines WITH_CUDA or WITH_HIP -> ROIAlign_backward_cuda
///   - CUDA tensor, build defines neither               -> AT_ERROR
///   - any other tensor                                 -> ROIAlign_backward_cpu
///
/// Every argument is handed to the selected backend unchanged and the
/// backend's result is returned as-is. `rois` is not inspected here.
inline at::Tensor ROIAlign_backward(
    const at::Tensor& grad,
    const at::Tensor& rois,
    const float spatial_scale,
    const int pooled_height,
    const int pooled_width,
    const int batch_size,
    const int channels,
    const int height,
    const int width,
    const int sampling_ratio,
    bool aligned) {
  const bool on_gpu = grad.is_cuda();
#if defined(WITH_CUDA) || defined(WITH_HIP)
  // GPU-enabled build: CUDA gradients go to the GPU implementation.
  if (on_gpu) {
    return ROIAlign_backward_cuda(
        grad,
        rois,
        spatial_scale,
        pooled_height,
        pooled_width,
        batch_size,
        channels,
        height,
        width,
        sampling_ratio,
        aligned);
  }
#else
  // Build without GPU support: a CUDA gradient cannot be served.
  if (on_gpu) {
    AT_ERROR("Not compiled with GPU support");
  }
#endif
  // Anything that is not a CUDA tensor takes the CPU path.
  return ROIAlign_backward_cpu(
      grad,
      rois,
      spatial_scale,
      pooled_height,
      pooled_width,
      batch_size,
      channels,
      height,
      width,
      sampling_ratio,
      aligned);
}
} // namespace detectron2