Luisgust commited on
Commit
c409b28
·
verified ·
1 Parent(s): 20c7783

Create vtoonify/model/raft/alt_cuda_corr/correlation.cpp

Browse files
vtoonify/model/raft/alt_cuda_corr/correlation.cpp ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #include <torch/extension.h>
3
+ #include <vector>
4
+
5
+ // CUDA forward declarations
6
+ std::vector<torch::Tensor> corr_cuda_forward(
7
+ torch::Tensor fmap1,
8
+ torch::Tensor fmap2,
9
+ torch::Tensor coords,
10
+ int radius);
11
+
12
+ std::vector<torch::Tensor> corr_cuda_backward(
13
+ torch::Tensor fmap1,
14
+ torch::Tensor fmap2,
15
+ torch::Tensor coords,
16
+ torch::Tensor corr_grad,
17
+ int radius);
18
+
19
+ // C++ interface
20
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
21
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
22
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
23
+
24
+ std::vector<torch::Tensor> corr_forward(
25
+ torch::Tensor fmap1,
26
+ torch::Tensor fmap2,
27
+ torch::Tensor coords,
28
+ int radius) {
29
+ CHECK_INPUT(fmap1);
30
+ CHECK_INPUT(fmap2);
31
+ CHECK_INPUT(coords);
32
+
33
+ return corr_cuda_forward(fmap1, fmap2, coords, radius);
34
+ }
35
+
36
+
37
+ std::vector<torch::Tensor> corr_backward(
38
+ torch::Tensor fmap1,
39
+ torch::Tensor fmap2,
40
+ torch::Tensor coords,
41
+ torch::Tensor corr_grad,
42
+ int radius) {
43
+ CHECK_INPUT(fmap1);
44
+ CHECK_INPUT(fmap2);
45
+ CHECK_INPUT(coords);
46
+ CHECK_INPUT(corr_grad);
47
+
48
+ return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
49
+ }
50
+
51
+
52
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
53
+ m.def("forward", &corr_forward, "CORR forward");
54
+ m.def("backward", &corr_backward, "CORR backward");
55
+ }