Spaces:
Running
Running
File size: 4,709 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from .conv_module import ConvModule
class DepthwiseSeparableConvModule(nn.Module):
    """Depthwise separable convolution built from two ``ConvModule`` blocks.

    The design follows MobileNets (https://arxiv.org/pdf/1704.04861.pdf).
    A single ConvModule is replaced by a pair applied in sequence:

    1. a depthwise ConvModule mapping ``in_channels`` to ``in_channels`` with
       ``groups=in_channels`` and the given kernel size, stride, padding and
       dilation;
    2. a pointwise ConvModule mapping ``in_channels`` to ``out_channels``
       with kernel size 1.

    Each of the two blocks receives its own norm and activation config. Note
    that the depthwise block also gets norm/activation configs whenever
    `norm_cfg` / `act_cfg` are specified, unless overridden through
    `dw_norm_cfg` / `dw_act_cfg`.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the depthwise kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the depthwise convolution.
            Same as that in ``nn._ConvNd``. Default: 1.
        padding (int | tuple[int]): Zero-padding of the depthwise
            convolution, added to both sides of the input.
            Same as that in ``nn._ConvNd``. Default: 0.
        dilation (int | tuple[int]): Spacing between depthwise kernel
            elements. Same as that in ``nn._ConvNd``. Default: 1.
        norm_cfg (dict): Shared fallback norm config for both the depthwise
            and the pointwise ConvModule. Default: None.
        act_cfg (dict): Shared fallback activation config for both the
            depthwise and the pointwise ConvModule.
            Default: dict(type='ReLU').
        dw_norm_cfg (dict | str): Norm config of the depthwise ConvModule.
            'default' means "use `norm_cfg`". Default: 'default'.
        dw_act_cfg (dict | str): Activation config of the depthwise
            ConvModule. 'default' means "use `act_cfg`". Default: 'default'.
        pw_norm_cfg (dict | str): Norm config of the pointwise ConvModule.
            'default' means "use `norm_cfg`". Default: 'default'.
        pw_act_cfg (dict | str): Activation config of the pointwise
            ConvModule. 'default' means "use `act_cfg`". Default: 'default'.
        kwargs (optional): Extra keyword arguments forwarded unchanged to
            both ConvModules. ``groups`` is not allowed here because the
            depthwise block sets it itself. See ConvModule for ref.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 norm_cfg: Optional[Dict] = None,
                 act_cfg: Dict = dict(type='ReLU'),
                 dw_norm_cfg: Union[Dict, str] = 'default',
                 dw_act_cfg: Union[Dict, str] = 'default',
                 pw_norm_cfg: Union[Dict, str] = 'default',
                 pw_act_cfg: Union[Dict, str] = 'default',
                 **kwargs):
        super().__init__()
        # The depthwise block fixes ``groups`` itself, so callers may not
        # pass it through the shared kwargs.
        assert 'groups' not in kwargs, 'groups should not be specified'

        # Depthwise stage: one filter group per input channel, channel count
        # unchanged. Falls back to the shared configs on 'default'.
        depthwise_norm = self._resolve_cfg(dw_norm_cfg, norm_cfg)
        depthwise_act = self._resolve_cfg(dw_act_cfg, act_cfg)
        self.depthwise_conv = ConvModule(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            norm_cfg=depthwise_norm,  # type: ignore
            act_cfg=depthwise_act,  # type: ignore
            **kwargs)

        # Pointwise stage: 1x1 convolution that changes the channel count
        # from ``in_channels`` to ``out_channels``.
        pointwise_norm = self._resolve_cfg(pw_norm_cfg, norm_cfg)
        pointwise_act = self._resolve_cfg(pw_act_cfg, act_cfg)
        self.pointwise_conv = ConvModule(
            in_channels,
            out_channels,
            1,
            norm_cfg=pointwise_norm,  # type: ignore
            act_cfg=pointwise_act,  # type: ignore
            **kwargs)

    @staticmethod
    def _resolve_cfg(specific, shared):
        """Return ``specific`` unless it is 'default', else ``shared``."""
        if specific != 'default':
            return specific
        return shared

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the depthwise block, then the pointwise block, to ``x``."""
        return self.pointwise_conv(self.depthwise_conv(x))
|