Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from copy import deepcopy | |
| import fvcore.nn.weight_init as weight_init | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .batch_norm import get_norm | |
| from .blocks import DepthwiseSeparableConv2d | |
| from .wrappers import Conv2d | |
| class ASPP(nn.Module): | |
| """ | |
| Atrous Spatial Pyramid Pooling (ASPP). | |
| """ | |
| def __init__( | |
| self, | |
| in_channels, | |
| out_channels, | |
| dilations, | |
| *, | |
| norm, | |
| activation, | |
| pool_kernel_size=None, | |
| dropout: float = 0.0, | |
| use_depthwise_separable_conv=False, | |
| ): | |
| """ | |
| Args: | |
| in_channels (int): number of input channels for ASPP. | |
| out_channels (int): number of output channels. | |
| dilations (list): a list of 3 dilations in ASPP. | |
| norm (str or callable): normalization for all conv layers. | |
| See :func:`layers.get_norm` for supported format. norm is | |
| applied to all conv layers except the conv following | |
| global average pooling. | |
| activation (callable): activation function. | |
| pool_kernel_size (tuple, list): the average pooling size (kh, kw) | |
| for image pooling layer in ASPP. If set to None, it always | |
| performs global average pooling. If not None, it must be | |
| divisible by the shape of inputs in forward(). It is recommended | |
| to use a fixed input feature size in training, and set this | |
| option to match this size, so that it performs global average | |
| pooling in training, and the size of the pooling window stays | |
| consistent in inference. | |
| dropout (float): apply dropout on the output of ASPP. It is used in | |
| the official DeepLab implementation with a rate of 0.1: | |
| https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa | |
| use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d | |
| for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`. | |
| """ | |
| super(ASPP, self).__init__() | |
| assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations)) | |
| self.pool_kernel_size = pool_kernel_size | |
| self.dropout = dropout | |
| use_bias = norm == "" | |
| self.convs = nn.ModuleList() | |
| # conv 1x1 | |
| self.convs.append( | |
| Conv2d( | |
| in_channels, | |
| out_channels, | |
| kernel_size=1, | |
| bias=use_bias, | |
| norm=get_norm(norm, out_channels), | |
| activation=deepcopy(activation), | |
| ) | |
| ) | |
| weight_init.c2_xavier_fill(self.convs[-1]) | |
| # atrous convs | |
| for dilation in dilations: | |
| if use_depthwise_separable_conv: | |
| self.convs.append( | |
| DepthwiseSeparableConv2d( | |
| in_channels, | |
| out_channels, | |
| kernel_size=3, | |
| padding=dilation, | |
| dilation=dilation, | |
| norm1=norm, | |
| activation1=deepcopy(activation), | |
| norm2=norm, | |
| activation2=deepcopy(activation), | |
| ) | |
| ) | |
| else: | |
| self.convs.append( | |
| Conv2d( | |
| in_channels, | |
| out_channels, | |
| kernel_size=3, | |
| padding=dilation, | |
| dilation=dilation, | |
| bias=use_bias, | |
| norm=get_norm(norm, out_channels), | |
| activation=deepcopy(activation), | |
| ) | |
| ) | |
| weight_init.c2_xavier_fill(self.convs[-1]) | |
| # image pooling | |
| # We do not add BatchNorm because the spatial resolution is 1x1, | |
| # the original TF implementation has BatchNorm. | |
| if pool_kernel_size is None: | |
| image_pooling = nn.Sequential( | |
| nn.AdaptiveAvgPool2d(1), | |
| Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)), | |
| ) | |
| else: | |
| image_pooling = nn.Sequential( | |
| nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1), | |
| Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)), | |
| ) | |
| weight_init.c2_xavier_fill(image_pooling[1]) | |
| self.convs.append(image_pooling) | |
| self.project = Conv2d( | |
| 5 * out_channels, | |
| out_channels, | |
| kernel_size=1, | |
| bias=use_bias, | |
| norm=get_norm(norm, out_channels), | |
| activation=deepcopy(activation), | |
| ) | |
| weight_init.c2_xavier_fill(self.project) | |
| def forward(self, x): | |
| size = x.shape[-2:] | |
| if self.pool_kernel_size is not None: | |
| if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]: | |
| raise ValueError( | |
| "`pool_kernel_size` must be divisible by the shape of inputs. " | |
| "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size) | |
| ) | |
| res = [] | |
| for conv in self.convs: | |
| res.append(conv(x)) | |
| res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False) | |
| res = torch.cat(res, dim=1) | |
| res = self.project(res) | |
| res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res | |
| return res | |