import torch.nn as nn import torch.nn.functional as F class DUC(nn.Module): ''' INPUT: inplanes, planes, upscale_factor OUTPUT: (planes // 4)* ht * wd ''' def __init__(self, inplanes, planes, upscale_factor=2): super(DUC, self).__init__() self.conv = nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, bias=False) self.bn = nn.BatchNorm2d(planes) self.relu = nn.ReLU() self.pixel_shuffle = nn.PixelShuffle(upscale_factor) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) x = self.pixel_shuffle(x) return x