# Copyright (C) 2020 NVIDIA Corporation.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
from functools import partial

import torch
import torch.nn as nn

from imaginaire.layers import Conv2dBlock


class NonLocal2dBlock(nn.Module):
    r"""Self attention Layer

    Args:
        in_channels (int): Number of channels in the input tensor.
        scale (bool, optional, default=True): If ``True``, scale the
            output by a learnable parameter.
        clamp (bool, optional, default=``False``): If ``True``, clamp the
            scaling parameter to (-1, 1).
        weight_norm_type (str, optional, default='none'):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
    """

    def __init__(self,
                 in_channels,
                 scale=True,
                 clamp=False,
                 weight_norm_type='none'):
        super(NonLocal2dBlock, self).__init__()
        self.clamp = clamp
        # Learnable residual scale. Initialized to zero so the block starts
        # as an identity mapping and learns to attend gradually.
        self.gamma = nn.Parameter(torch.zeros(1)) if scale else 1.0
        self.in_channels = in_channels
        base_conv2d_block = partial(Conv2dBlock,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0,
                                    weight_norm_type=weight_norm_type)
        # 1x1 convolutions projecting the input to query (theta), key (phi)
        # and value (g) feature maps.
        self.theta = base_conv2d_block(in_channels, in_channels // 8)
        self.phi = base_conv2d_block(in_channels, in_channels // 8)
        self.g = base_conv2d_block(in_channels, in_channels // 2)
        self.out_conv = base_conv2d_block(in_channels // 2, in_channels)
        self.softmax = nn.Softmax(dim=-1)
        self.max_pool = nn.MaxPool2d(2)

    def forward(self, x):
        r"""

        Args:
            x (tensor) : input feature maps (B X C X W X H)
        Returns:
            (tuple):
              - out (tensor) : self attention value + input feature
              - attention (tensor): B x N x N (N is Width*Height)
        """
        n, c, h, w = x.size()
        # Query projection, flattened to (B, H*W, C/8).
        theta = self.theta(x).view(n, -1, h * w).permute(0, 2, 1)

        # Key projection, max-pooled to (B, C/8, H*W/4) to save memory.
        phi = self.phi(x)
        phi = self.max_pool(phi).view(n, -1, h * w // 4)

        # Attention map over spatial positions: (B, H*W, H*W/4).
        energy = torch.bmm(theta, phi)
        attention = self.softmax(energy)

        # Value projection, max-pooled to (B, C/2, H*W/4).
        g = self.g(x)
        g = self.max_pool(g).view(n, -1, h * w // 4)

        # Aggregate values with the attention weights, then project back
        # to the original channel count.
        out = torch.bmm(g, attention.permute(0, 2, 1))
        out = out.view(n, c // 2, h, w)
        out = self.out_conv(out)

        # Residual connection with an (optionally clamped) learnable scale.
        if self.clamp:
            out = self.gamma.clamp(-1, 1) * out + x
        else:
            out = self.gamma * out + x
        return out
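

# A minimal usage sketch (an assumption for illustration, not part of the
# original module): it feeds random feature maps through the block and
# checks that the gamma-scaled residual preserves the input shape. Spatial
# dimensions should be even, since the keys and values are max-pooled by a
# factor of 2 before flattening. Requires the imaginaire package so that
# Conv2dBlock is importable.
if __name__ == '__main__':
    block = NonLocal2dBlock(in_channels=64, scale=True, clamp=False)
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    assert y.shape == x.shape
    print('NonLocal2dBlock output shape:', tuple(y.shape))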