|
#pragma once |
|
|
|
#include <cstdint>
#include <vector>

#include <torch/torch.h>
|
|
|
/// @brief Forward entry point of what its name indicates is the CUDA
///        multi-scale deformable attention kernel.
///
/// Only the declaration lives in this header; the definition is provided
/// elsewhere (NOTE(review): presumably a .cu translation unit — confirm).
///
/// NOTE(review): the tensor layouts mentioned below follow the usual
/// Deformable-DETR convention and are NOT established by this header —
/// confirm against the implementation before relying on them.
///
/// @param value             Feature values to sample from (presumably
///                          (N, sum_l H_l*W_l, heads, head_dim) — TODO confirm).
/// @param spatial_shapes    Per-level spatial sizes (presumably (L, 2) holding
///                          H, W — TODO confirm).
/// @param level_start_index Per-level start offsets into `value`'s flattened
///                          spatial axis (presumably (L,) — TODO confirm).
/// @param sampling_loc      Sampling locations (presumably
///                          (N, Lq, heads, L, P, 2) — TODO confirm).
/// @param attn_weight       Per-sample attention weights (presumably
///                          (N, Lq, heads, L, P) — TODO confirm).
/// @param im2col_step       Step/chunk size consumed by the implementation;
///                          its exact meaning is not visible here — confirm.
/// @return The output tensor computed by the implementation.
///
/// Note: `im2col_step` is passed by value, so it is declared without a
/// top-level `const`; a definition may still add one without changing the
/// function's type.
at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
                                       const at::Tensor &spatial_shapes,
                                       const at::Tensor &level_start_index,
                                       const at::Tensor &sampling_loc,
                                       const at::Tensor &attn_weight,
                                       int64_t im2col_step);
|
|
|
std::vector<at::Tensor> ms_deform_attn_cuda_backward( |
|
const at::Tensor &value, const at::Tensor &spatial_shapes, |
|
const at::Tensor &level_start_index, const at::Tensor &sampling_loc, |
|
const at::Tensor &attn_weight, const at::Tensor &grad_output, |
|
const int64_t im2col_step); |
|
|