# Original from: https://github.com/xavysp/TEED # TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3 # with a Slightly modification # LDC parameters: # 155665 # TED > 58K import torch import torch.nn as nn import torch.nn.functional as F from ...modeling_utils import ModelMixin """ smish_function and Smish script based on: Wang, Xueliang, Honge Ren, and Achuan Wang. "Smish: A Novel Activation Function for Deep Learning Methods. " Electronics 11.4 (2022): 540. smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) """ @torch.jit.script def smish_function(input): """ Applies the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) See additional documentation for mish class. """ return input * torch.tanh(torch.log(1 + torch.sigmoid(input))) class Smish(nn.Module): """ Applies the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) Shape: - Input: (N, *) where * means, any number of additional dimensions - Output: (N, *), same shape as the input Examples: >>> m = Mish() >>> input = torch.randn(2) >>> output = m(input) Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html """ def __init__(self): """ Init method. """ super().__init__() def forward(self, input): """ Forward pass of the function. """ return smish_function(input) def weight_init(m): if isinstance(m, (nn.Conv2d,)): torch.nn.init.xavier_normal_(m.weight, gain=1.0) if m.bias is not None: torch.nn.init.zeros_(m.bias) # for fusion layer if isinstance(m, (nn.ConvTranspose2d,)): torch.nn.init.xavier_normal_(m.weight, gain=1.0) if m.bias is not None: torch.nn.init.zeros_(m.bias) class CoFusion(nn.Module): # from LDC def __init__(self, in_ch, out_ch): super(CoFusion, self).__init__() self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64 self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32 self.relu = nn.ReLU() self.norm_layer1 = nn.GroupNorm(4, 32) # before 64 def forward(self, x): # fusecat = torch.cat(x, dim=1) attn = self.relu(self.norm_layer1(self.conv1(x))) attn = F.softmax(self.conv3(attn), dim=1) return ((x * attn).sum(1)).unsqueeze(1) class CoFusion2(nn.Module): # TEDv14-3 def __init__(self, in_ch, out_ch): super(CoFusion2, self).__init__() self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1) # before 64 self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, stride=1, padding=1) # before 64 instead of 32 self.smish = Smish() # nn.ReLU(inplace=True) def forward(self, x): attn = self.conv1(self.smish(x)) attn = self.conv3(self.smish(attn)) # before , )dim=1) return ((x * attn).sum(1)).unsqueeze(1) class DoubleFusion(nn.Module): # TED fusion before the final edge map prediction def __init__(self, in_ch, out_ch): super(DoubleFusion, self).__init__() self.DWconv1 = nn.Conv2d(in_ch, in_ch * 8, kernel_size=3, stride=1, padding=1, groups=in_ch) # before 64 self.PSconv1 = nn.PixelShuffle(1) self.DWconv2 = nn.Conv2d(24, 24 * 1, kernel_size=3, stride=1, padding=1, groups=24) # before 64 instead of 32 self.AF = Smish() # XAF() #nn.Tanh()# XAF() # # Smish()# def forward(self, x): attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352] attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352] return smish_function(((attn2 + attn).sum(1)).unsqueeze(1)) # TED best res class _DenseLayer(nn.Sequential): def __init__(self, input_features, out_features): super(_DenseLayer, self).__init__() ( self.add_module( "conv1", nn.Conv2d( input_features, out_features, kernel_size=3, stride=1, padding=2, bias=True, ), ), ) (self.add_module("smish1", Smish()),) self.add_module( "conv2", nn.Conv2d(out_features, out_features, kernel_size=3, stride=1, bias=True), ) def forward(self, x): x1, x2 = x new_features = super(_DenseLayer, self).forward(smish_function(x1)) # F.relu() return 0.5 * (new_features + x2), x2 class _DenseBlock(nn.Sequential): def __init__(self, num_layers, input_features, out_features): super(_DenseBlock, self).__init__() for i in range(num_layers): layer = _DenseLayer(input_features, out_features) self.add_module("denselayer%d" % (i + 1), layer) input_features = out_features class UpConvBlock(nn.Module): def __init__(self, in_features, up_scale): super(UpConvBlock, self).__init__() self.up_factor = 2 self.constant_features = 16 layers = self.make_deconv_layers(in_features, up_scale) assert layers is not None, layers self.features = nn.Sequential(*layers) def make_deconv_layers(self, in_features, up_scale): layers = [] all_pads = [0, 0, 1, 3, 7] for i in range(up_scale): kernel_size = 2**up_scale pad = all_pads[up_scale] # kernel_size-1 out_features = self.compute_out_features(i, up_scale) layers.append(nn.Conv2d(in_features, out_features, 1)) layers.append(Smish()) layers.append(nn.ConvTranspose2d(out_features, out_features, kernel_size, stride=2, padding=pad)) in_features = out_features return layers def compute_out_features(self, idx, up_scale): return 1 if idx == up_scale - 1 else self.constant_features def forward(self, x): return self.features(x) class SingleConvBlock(nn.Module): def __init__(self, in_features, out_features, stride, use_ac=False): super(SingleConvBlock, self).__init__() self.use_ac = use_ac self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, bias=True) if self.use_ac: self.smish = Smish() def forward(self, x): x = self.conv(x) if self.use_ac: return self.smish(x) else: return x class DoubleConvBlock(nn.Module): def __init__(self, in_features, mid_features, out_features=None, stride=1, use_act=True): super(DoubleConvBlock, self).__init__() self.use_act = use_act if out_features is None: out_features = mid_features self.conv1 = nn.Conv2d(in_features, mid_features, 3, padding=1, stride=stride) self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1) self.smish = Smish() # nn.ReLU(inplace=True) def forward(self, x): x = self.conv1(x) x = self.smish(x) x = self.conv2(x) if self.use_act: x = self.smish(x) return x class TEED(ModelMixin): """Definition of Tiny and Efficient Edge Detector model """ def __init__(self): super(TEED, self).__init__() self.block_1 = DoubleConvBlock( 3, 16, 16, stride=2, ) self.block_2 = DoubleConvBlock(16, 32, use_act=False) self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # skip1 connection, see fig. 2 self.side_1 = SingleConvBlock(16, 32, 2) # skip2 connection, see fig. 2 self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1) # USNet self.up_block_1 = UpConvBlock(16, 1) self.up_block_2 = UpConvBlock(32, 1) self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1) self.block_cat = DoubleFusion(3, 3) # TEED: DoubleFusion self.apply(weight_init) def slice(self, tensor, slice_shape): t_shape = tensor.shape img_h, img_w = slice_shape if img_w != t_shape[-1] or img_h != t_shape[2]: new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False) else: new_tensor = tensor # tensor[..., :height, :width] return new_tensor def resize_input(self, tensor): t_shape = tensor.shape if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0: img_w = ((t_shape[3] // 8) + 1) * 8 img_h = ((t_shape[2] // 8) + 1) * 8 new_tensor = F.interpolate(tensor, size=(img_h, img_w), mode="bicubic", align_corners=False) else: new_tensor = tensor return new_tensor def crop_bdcn(data1, h, w, crop_h, crop_w): # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN _, _, h1, w1 = data1.size() assert h <= h1 and w <= w1 data = data1[:, :, crop_h : crop_h + h, crop_w : crop_w + w] return data def forward(self, x, single_test=False): assert x.ndim == 4, x.shape # supose the image size is 352x352 # Block 1 block_1 = self.block_1(x) # [8,16,176,176] block_1_side = self.side_1(block_1) # 16 [8,32,88,88] # Block 2 block_2 = self.block_2(block_1) # 32 # [8,32,176,176] block_2_down = self.maxpool(block_2) # [8,32,88,88] block_2_add = block_2_down + block_1_side # [8,32,88,88] # Block 3 block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88] # upsampling blocks out_1 = self.up_block_1(block_1) out_2 = self.up_block_2(block_2) out_3 = self.up_block_3(block_3) results = [out_1, out_2, out_3] # concatenate multiscale outputs block_cat = torch.cat(results, dim=1) # Bx6xHxW block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion results.append(block_cat) return results