Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
"""Custom operators.""" | |
import torch | |
import torch.nn as nn | |
class Swish(nn.Module):
    """Activation module computing ``x * sigmoid(x)``.

    The computation is delegated to the custom autograd function
    ``SwishEfficient``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the Swish activation element-wise to ``x``."""
        swish_fn = SwishEfficient.apply
        return swish_fn(x)
class SwishEfficient(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward.

    Only the input ``x`` is saved for the backward pass; the gradient
    ``sigmoid(x) * (1 + x * (1 - sigmoid(x)))`` is recomputed from it.
    """

    # forward/backward must be static methods: torch.autograd.Function
    # rejects legacy (non-static) definitions when `.apply` is called.
    @staticmethod
    def forward(ctx, x):
        """Return ``x * sigmoid(x)`` and stash ``x`` for ``backward``."""
        result = x * torch.sigmoid(x)
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` with d/dx [x * sigmoid(x)]."""
        # `saved_tensors` replaces the deprecated `saved_variables` accessor.
        (x,) = ctx.saved_tensors
        sigmoid_x = torch.sigmoid(x)
        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x)) = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block: AvgPool, FC, ReLU/Swish, FC, Sigmoid.

    The "FC" layers are 1x1x1 ``nn.Conv3d`` layers applied to the globally
    average-pooled input; the resulting sigmoid gates rescale the input
    channel-wise.
    """

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Round width of filters based on width multiplier.

        Args:
            width (int): the channel dimensions of the input.
            multiplier (float): the multiplication factor. A falsy value
                (e.g. ``0`` or ``None``) returns ``width`` unchanged.
            min_width (int): the minimum width after multiplication; a falsy
                value falls back to ``divisor``.
            divisor (int): the new width should be divisible by divisor.

        Returns:
            int: ``width * multiplier`` rounded to the nearest multiple of
            ``divisor`` (and at least ``min_width``), increased by one more
            ``divisor`` if rounding down lost more than 10% of the scaled
            width.
        """
        if not multiplier:
            return width
        width *= multiplier
        min_width = min_width or divisor
        # Round to the nearest multiple of `divisor` (halves round up).
        width_out = max(
            min_width, int(width + divisor / 2) // divisor * divisor
        )
        # Do not let rounding shrink the width by more than 10%.
        if width_out < 0.9 * width:
            width_out += divisor
        return int(width_out)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimensions of the input.
            ratio (float): the channel reduction ratio for squeeze; the
                bottleneck width is ``dim_in * ratio`` rounded by
                ``_round_width`` (multiple of 8, at least 8).
            relu_act (bool): if True (the default), use ReLU after the first
                FC layer; otherwise use Swish.
        """
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        dim_fc = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)
        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        """Rescale ``x`` channel-wise by its squeeze-and-excitation gates.

        ``x`` must be an input accepted by ``nn.Conv3d`` with ``dim_in``
        channels (e.g. shape (N, C, D, H, W)). The output has the same shape.
        """
        # Apply the layers explicitly rather than iterating self.children(),
        # so the pipeline does not depend on module registration order.
        gate = self.avg_pool(x)
        gate = self.fc1(gate)
        gate = self.fc1_act(gate)
        gate = self.fc2(gate)
        gate = self.fc2_sig(gate)
        return x * gate