# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Non-local helper."""

import torch
import torch.nn as nn


class Nonlocal(nn.Module):
""" | |
Builds Non-local Neural Networks as a generic family of building | |
blocks for capturing long-range dependencies. Non-local Network | |
computes the response at a position as a weighted sum of the | |
features at all positions. This building block can be plugged into | |
many computer vision architectures. | |
More details in the paper: https://arxiv.org/pdf/1711.07971.pdf | |
""" | |
    def __init__(
        self,
        dim,
        dim_inner,
        pool_size=None,
        instantiation="softmax",
        zero_init_final_conv=False,
        zero_init_final_norm=True,
        norm_eps=1e-5,
        norm_momentum=0.1,
        norm_module=nn.BatchNorm3d,
    ):
""" | |
Args: | |
dim (int): number of dimension for the input. | |
dim_inner (int): number of dimension inside of the Non-local block. | |
pool_size (list): the kernel size of spatial temporal pooling, | |
temporal pool kernel size, spatial pool kernel size, spatial | |
pool kernel size in order. By default pool_size is None, | |
then there would be no pooling used. | |
instantiation (string): supports two different instantiation method: | |
"dot_product": normalizing correlation matrix with L2. | |
"softmax": normalizing correlation matrix with Softmax. | |
zero_init_final_conv (bool): If true, zero initializing the final | |
convolution of the Non-local block. | |
zero_init_final_norm (bool): | |
If true, zero initializing the final batch norm of the Non-local | |
block. | |
norm_module (nn.Module): nn.Module for the normalization layer. The | |
default is nn.BatchNorm3d. | |
""" | |
        super(Nonlocal, self).__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.pool_size = pool_size
        self.instantiation = instantiation
        self.use_pool = pool_size is not None and any(
            size > 1 for size in pool_size
        )
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self._construct_nonlocal(
            zero_init_final_conv, zero_init_final_norm, norm_module
        )

    def _construct_nonlocal(
        self, zero_init_final_conv, zero_init_final_norm, norm_module
    ):
        # Three convolution heads: theta, phi, and g.
        self.conv_theta = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_phi = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )
        self.conv_g = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
        )

        # Final convolution output.
        self.conv_out = nn.Conv3d(
            self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
        )
        # Flag the final convolution for zero initialization; the attribute
        # is read by the weight-initialization code.
        self.conv_out.zero_init = zero_init_final_conv

        # TODO: change the name to `norm`.
        self.bn = norm_module(
            num_features=self.dim,
            eps=self.norm_eps,
            momentum=self.norm_momentum,
        )
        # Flag the final batch norm for zero initialization; the attribute
        # is read by the weight-initialization code.
        self.bn.transform_final_bn = zero_init_final_norm

        # Optional spatiotemporal pooling.
        if self.use_pool:
            self.pool = nn.MaxPool3d(
                kernel_size=self.pool_size,
                stride=self.pool_size,
                padding=[0, 0, 0],
            )

    def forward(self, x):
        """
        Args:
            x (tensor): input of shape (N, C, T, H, W).
        """
        x_identity = x
        N, C, T, H, W = x.size()

        theta = self.conv_theta(x)

        # Perform spatiotemporal pooling on phi and g to reduce computation;
        # theta keeps the full resolution.
        if self.use_pool:
            x = self.pool(x)

        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, THW) * (N, C, T'H'W') => (N, THW, T'H'W'), where T'H'W' is
        # the (possibly pooled) number of positions of phi and g.
        theta_phi = torch.einsum("nct,ncp->ntp", theta, phi)
        # The original Non-local paper gives two main ways to normalize the
        # affinity tensor:
        # 1) Softmax normalization (norm on exp).
        # 2) dot_product normalization (divide by the number of positions).
        if self.instantiation == "softmax":
            # Scale theta_phi by dim_inner ** -0.5 before the softmax.
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError(
                "Unknown norm type {}".format(self.instantiation)
            )

        # (N, THW, T'H'W') * (N, C, T'H'W') => (N, C, THW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", theta_phi, g)
        # (N, C, THW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)

        p = self.conv_out(theta_phi_g)
        p = self.bn(p)
        return x_identity + p
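

if __name__ == "__main__":
    # A minimal smoke test sketching how this block might be wired up. The
    # channel counts, pool_size, and input shape below are illustrative
    # assumptions, not values taken from any particular model config.
    block = Nonlocal(dim=64, dim_inner=32, pool_size=[1, 2, 2])
    x = torch.randn(2, 64, 4, 16, 16)  # (N, C, T, H, W)
    out = block(x)
    # The residual connection means the output shape matches the input.
    assert out.shape == x.shape
    print(out.shape)  # torch.Size([2, 64, 4, 16, 16])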