# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Non-local helper"""
import torch
import torch.nn as nn


class Nonlocal(nn.Module):
"""
Builds Non-local Neural Networks as a generic family of building
blocks for capturing long-range dependencies. Non-local Network
computes the response at a position as a weighted sum of the
features at all positions. This building block can be plugged into
many computer vision architectures.
More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
"""

    def __init__(
self,
dim,
dim_inner,
pool_size=None,
instantiation="softmax",
zero_init_final_conv=False,
zero_init_final_norm=True,
norm_eps=1e-5,
norm_momentum=0.1,
norm_module=nn.BatchNorm3d,
):
"""
Args:
            dim (int): number of input channels.
            dim_inner (int): number of channels inside the Non-local block.
            pool_size (list): kernel size of the spatiotemporal pooling,
                given in the order [temporal kernel size, spatial kernel
                size, spatial kernel size]. By default pool_size is None,
                in which case no pooling is used.
            instantiation (string): supports two different instantiation
                methods:
                "dot_product": normalizing the correlation matrix by the
                    number of positions.
                "softmax": normalizing the correlation matrix with Softmax.
            zero_init_final_conv (bool): if True, zero-initialize the final
                convolution of the Non-local block.
            zero_init_final_norm (bool): if True, zero-initialize the final
                batch norm of the Non-local block.
            norm_eps (float): epsilon for the normalization layer.
            norm_momentum (float): momentum for the normalization layer.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
"""
super(Nonlocal, self).__init__()
self.dim = dim
self.dim_inner = dim_inner
self.pool_size = pool_size
self.instantiation = instantiation
        self.use_pool = pool_size is not None and any(
            size > 1 for size in pool_size
        )
self.norm_eps = norm_eps
self.norm_momentum = norm_momentum
self._construct_nonlocal(
zero_init_final_conv, zero_init_final_norm, norm_module
)

    def _construct_nonlocal(
self, zero_init_final_conv, zero_init_final_norm, norm_module
):
# Three convolution heads: theta, phi, and g.
self.conv_theta = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
self.conv_phi = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
self.conv_g = nn.Conv3d(
self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0
)
# Final convolution output.
self.conv_out = nn.Conv3d(
self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0
)
        # Flag the final convolution for zero initialization; the attribute
        # is read by the external weight-initialization routine, the
        # assignment itself does not change the weights.
        self.conv_out.zero_init = zero_init_final_conv
# TODO: change the name to `norm`
self.bn = norm_module(
num_features=self.dim,
eps=self.norm_eps,
momentum=self.norm_momentum,
)
        # Flag the final bn for zero initialization (also consumed by the
        # weight-initialization routine), so the block starts out as an
        # identity mapping through the residual connection.
        self.bn.transform_final_bn = zero_init_final_norm
        # Optionally add spatiotemporal max pooling to reduce the cost of
        # the affinity computation.
if self.use_pool:
self.pool = nn.MaxPool3d(
kernel_size=self.pool_size,
stride=self.pool_size,
padding=[0, 0, 0],
)

    def forward(self, x):
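        """
        Args:
            x (tensor): input tensor of shape (N, C, T, H, W).
        Returns:
            tensor: the input plus the Non-local response, same shape as x.
        """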
x_identity = x
N, C, T, H, W = x.size()
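        # theta is computed on the un-pooled input so the response keeps the
        # full (T, H, W) resolution; only phi and g see the pooled input.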
theta = self.conv_theta(x)
# Perform temporal-spatial pooling to reduce the computation.
if self.use_pool:
x = self.pool(x)
phi = self.conv_phi(x)
g = self.conv_g(x)
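        # Flatten the spatiotemporal dimensions into a single position axis;
        # theta keeps T*H*W positions, phi and g may have fewer if pooled.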
theta = theta.view(N, self.dim_inner, -1)
phi = phi.view(N, self.dim_inner, -1)
g = g.view(N, self.dim_inner, -1)
# (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
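        # Each entry of theta_phi is the dot-product affinity between one
        # theta position and one phi position.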
        # The original Non-local paper describes two main ways to normalize
        # the affinity tensor:
        # 1) softmax normalization (norm on exp).
        # 2) dot_product normalization (divide by the number of positions).
if self.instantiation == "softmax":
# Normalizing the affinity tensor theta_phi before softmax.
theta_phi = theta_phi * (self.dim_inner ** -0.5)
theta_phi = nn.functional.softmax(theta_phi, dim=2)
elif self.instantiation == "dot_product":
spatial_temporal_dim = theta_phi.shape[2]
theta_phi = theta_phi / spatial_temporal_dim
else:
raise NotImplementedError(
"Unknown norm type {}".format(self.instantiation)
)
# (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
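        # Aggregate g, weighting each position's features by the normalized
        # affinities computed above.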
theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
# (N, C, TxHxW) => (N, C, T, H, W).
theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)
p = self.conv_out(theta_phi_g)
p = self.bn(p)
return x_identity + p
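

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): run a random clip
# through a Nonlocal block and check that the residual connection preserves
# the input shape. All sizes below are illustrative assumptions.
if __name__ == "__main__":
    block = Nonlocal(dim=64, dim_inner=32, pool_size=[1, 2, 2])
    clip = torch.randn(2, 64, 4, 16, 16)  # (N, C, T, H, W)
    out = block(clip)
    assert out.shape == clip.shape
    print(out.shape)  # torch.Size([2, 64, 4, 16, 16])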