Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from typing import Tuple | |
import torch | |
from torch import Tensor, nn | |
from torch.autograd import Function | |
from torch.autograd.function import once_differentiable | |
from torch.nn.modules.utils import _pair | |
from ..utils import ext_loader | |
# Load the compiled ``_ext`` extension, requesting the two correlation kernels
# that ``CorrelationFunction`` calls in its forward and backward passes.
ext_module = ext_loader.load_ext(
    '_ext',
    ['correlation_forward', 'correlation_backward'],
)
class CorrelationFunction(Function):
    """Autograd function wrapping the ``correlation_forward`` /
    ``correlation_backward`` kernels of the compiled ``_ext`` module.

    ``forward`` stores the (pair-normalized) hyper-parameters on ``ctx`` so
    that ``backward`` and ``_output_size`` can read them back.
    """

    @staticmethod
    def forward(ctx,
                input1: Tensor,
                input2: Tensor,
                kernel_size: int = 1,
                max_displacement: int = 1,
                stride: int = 1,
                padding: int = 1,
                dilation: int = 1,
                dilation_patch: int = 1) -> Tensor:
        """Compute the correlation volume between ``input1`` and ``input2``.

        Args:
            ctx: Autograd context; receives both inputs (for backward) and
                the hyper-parameters normalized to pairs with ``_pair``.
            input1: First feature map. Assumed to be ``(N, C, H, W)``;
                ``_output_size`` reads dims 0, 2 and 3.
            input2: Second feature map, passed to the kernel with ``input1``.
            kernel_size: Sliding-window size (int or pair).
            max_displacement: Search radius; the output has
                ``2 * max_displacement + 1`` displacements per axis.
            stride: Sliding-window stride (int or pair).
            padding: Zero padding (int or pair).
                NOTE(review): defaults to 1 here but to 0 in ``Correlation``;
                kept unchanged for backward compatibility.
            dilation: Spacing inside the sliding window (int or pair).
            dilation_patch: Spacing between displacements (int or pair).

        Returns:
            Tensor of shape ``(N, patch, patch, H_out, W_out)`` where
            ``patch = 2 * max_displacement + 1``. It is allocated with
            ``input1.new_zeros`` and handed to the extension, which
            presumably fills it in place.
        """
        ctx.save_for_backward(input1, input2)

        kH, kW = ctx.kernel_size = _pair(kernel_size)
        ctx.patch_size = patch_size = max_displacement * 2 + 1
        dH, dW = ctx.stride = _pair(stride)
        padH, padW = ctx.padding = _pair(padding)
        dilationH, dilationW = ctx.dilation = _pair(dilation)
        dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
            dilation_patch)

        output_size = CorrelationFunction._output_size(ctx, input1)
        output = input1.new_zeros(output_size)

        ext_module.correlation_forward(
            input1,
            input2,
            output,
            kH=kH,
            kW=kW,
            patchH=patch_size,
            patchW=patch_size,
            padH=padH,
            padW=padW,
            dilationH=dilationH,
            dilationW=dilationW,
            dilation_patchH=dilation_patchH,
            dilation_patchW=dilation_patchW,
            dH=dH,
            dW=dW)

        return output

    @staticmethod
    @once_differentiable
    def backward(
        ctx, grad_output: Tensor
    ) -> Tuple[Tensor, Tensor, None, None, None, None, None, None]:
        """Backpropagate through the correlation via the extension kernel.

        The gradient buffers are written by an opaque compiled kernel that
        autograd cannot trace. ``once_differentiable`` therefore makes a
        second-order backward raise an error, instead of yielding gradients
        that cannot be differentiated further.

        Returns:
            Gradients w.r.t. ``input1`` and ``input2``, followed by ``None``
            for each of the six non-tensor hyper-parameters of ``forward``.
        """
        input1, input2 = ctx.saved_tensors

        kH, kW = ctx.kernel_size
        patch_size = ctx.patch_size
        padH, padW = ctx.padding
        dilationH, dilationW = ctx.dilation
        dilation_patchH, dilation_patchW = ctx.dilation_patch
        dH, dW = ctx.stride

        # Zero-initialized; presumably the kernel writes/accumulates into them.
        grad_input1 = torch.zeros_like(input1)
        grad_input2 = torch.zeros_like(input2)

        ext_module.correlation_backward(
            grad_output,
            input1,
            input2,
            grad_input1,
            grad_input2,
            kH=kH,
            kW=kW,
            patchH=patch_size,
            patchW=patch_size,
            padH=padH,
            padW=padW,
            dilationH=dilationH,
            dilationW=dilationW,
            dilation_patchH=dilation_patchH,
            dilation_patchW=dilation_patchW,
            dH=dH,
            dW=dW)

        return grad_input1, grad_input2, None, None, None, None, None, None

    @staticmethod
    def _output_size(ctx, input1: Tensor) -> Tuple[int, int, int, int, int]:
        """Return the output shape ``(N, patch, patch, H_out, W_out)``.

        ``H_out`` and ``W_out`` follow the convolution formula
        ``floor((in + 2 * pad - dilated_kernel) / stride) + 1``. The formula
        reads ``ctx.kernel_size``, ``ctx.stride``, ``ctx.padding``,
        ``ctx.dilation`` (all pairs) and ``ctx.patch_size``.
        """
        batch_size = input1.size(0)
        iH, iW = input1.size(2), input1.size(3)

        kH, kW = ctx.kernel_size
        dH, dW = ctx.stride
        padH, padW = ctx.padding
        dilationH, dilationW = ctx.dilation

        dilatedKH = (kH - 1) * dilationH + 1
        dilatedKW = (kW - 1) * dilationW + 1

        # Integer floor division (not ``int(x / s + 1)``): stays in exact
        # integer arithmetic and matches the documented floor formula rather
        # than truncating toward zero.
        oH = (iH + 2 * padH - dilatedKH) // dH + 1
        oW = (iW + 2 * padW - dilatedKW) // dW + 1

        return (batch_size, ctx.patch_size, ctx.patch_size, oH, oW)
class Correlation(nn.Module):
    r"""Correlation operator for optical-flow cost-volume computation.

    Given two batched tensors of shape :math:`(N, C, H, W)`, the output has
    shape :math:`(N, 2 \times max\_displacement + 1,
    2 \times max\_displacement + 1, H_{out}, W_{out})`, where

    .. math::
        H_{out} = \left\lfloor\frac{H_{in} + 2 \times padding -
            dilation \times (kernel\_size - 1) - 1}
            {stride} + 1\right\rfloor

    .. math::
        W_{out} = \left\lfloor\frac{W_{in} + 2 \times padding -
            dilation \times (kernel\_size - 1) - 1}
            {stride} + 1\right\rfloor

    The correlation item :math:`(N_i, dy, dx)` is a sliding-window
    convolution between ``input1`` and a shifted ``input2``:

    .. math::
        Corr(N_i, dy, dx) =
        \sum_{c=0}^{C-1}
        input1(N_i, c) \star
        \mathcal{S}(input2(N_i, c), dy, dx)

    Here :math:`\star` is the valid 2D sliding-window convolution operator and
    :math:`\mathcal{S}` shifts the input features, filling the vacated margin
    with zeros. :math:`dx, dy` are the shift distances, with
    :math:`dx, dy \in [-max\_displacement \times dilation\_patch,
    max\_displacement \times dilation\_patch]`.

    Args:
        kernel_size (int): Size of the sliding window, i.e. the local
            neighborhood around each center point that takes part in the
            correlation. Defaults to 1.
        max_displacement (int): Radius of the correlation search window;
            the effective search area is further dilated by
            ``dilation_patch``. Defaults to 1.
        stride (int): Stride of the sliding window over the input spatial
            dimensions. Defaults to 1.
        padding (int): Zero padding added to all four sides of ``input1``.
            Defaults to 0.
        dilation (int): Spacing between the elements of the local
            neighborhood used in the correlation. Defaults to 1.
        dilation_patch (int): Spacing between the displacements at which the
            correlation is computed. Defaults to 1.
    """

    def __init__(self,
                 kernel_size: int = 1,
                 max_displacement: int = 1,
                 stride: int = 1,
                 padding: int = 0,
                 dilation: int = 1,
                 dilation_patch: int = 1) -> None:
        super().__init__()
        # Stored as-is; ``forward`` hands them to CorrelationFunction.apply.
        (self.kernel_size, self.max_displacement, self.stride,
         self.padding, self.dilation, self.dilation_patch) = (
             kernel_size, max_displacement, stride,
             padding, dilation, dilation_patch)

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        """Correlate ``input1`` with ``input2`` using this module's settings."""
        hyper_params = (self.kernel_size, self.max_displacement, self.stride,
                        self.padding, self.dilation, self.dilation_patch)
        return CorrelationFunction.apply(input1, input2, *hyper_params)

    def __repr__(self) -> str:
        field_names = ('kernel_size', 'max_displacement', 'stride', 'padding',
                       'dilation', 'dilation_patch')
        rendered = ', '.join(
            f'{field}={getattr(self, field)}' for field in field_names)
        return f'{type(self).__name__}({rendered})'