import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter

from deepfillv2.network_utils import *


# -----------------------------------------------
#                Normal ConvBlock
# -----------------------------------------------
class Conv2dLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super(Conv2dLayer, self).__init__()
        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU(inplace=True)
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layer; padding is handled by self.pad above
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
        else:
            self.conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class TransposeConv2dLayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=False,
        scale_factor=2,
    ):
        super(TransposeConv2dLayer, self).__init__()
        # Upsample by nearest-neighbour interpolation, then apply a normal conv
        self.scale_factor = scale_factor
        self.conv2d = Conv2dLayer(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            pad_type,
            activation,
            norm,
            sn,
        )

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode="nearest",
            recompute_scale_factor=False,
        )
        x = self.conv2d(x)
        return x
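
# Illustrative usage sketch (not part of the original module): with
# padding = dilation * (kernel_size - 1) // 2 and stride = 1, Conv2dLayer
# keeps the spatial size, and TransposeConv2dLayer doubles it via nearest
# upsampling. The layer sizes below are hypothetical.
#
#   conv = Conv2dLayer(3, 64, kernel_size=3, stride=1, padding=1,
#                      activation="elu", norm="none")
#   up = TransposeConv2dLayer(64, 32, kernel_size=3, stride=1, padding=1)
#   x = torch.randn(1, 3, 64, 64)
#   y = up(conv(x))   # expected shape: (1, 32, 128, 128)
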
# -----------------------------------------------
#                Gated ConvBlock
# -----------------------------------------------
class GatedConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="reflect",
        activation="elu",
        norm="none",
        sn=False,
    ):
        super(GatedConv2d, self).__init__()
        # Initialize the padding scheme
        if pad_type == "reflect":
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == "replicate":
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == "zero":
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # Initialize the normalization type (stored, but not applied in forward)
        if norm == "bn":
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm == "in":
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm == "ln":
            self.norm = LayerNorm(out_channels)
        elif norm == "none":
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # Initialize the activation function
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "lrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "elu":
            self.activation = nn.ELU()
        elif activation == "selu":
            self.activation = nn.SELU(inplace=True)
        elif activation == "tanh":
            self.activation = nn.Tanh()
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "none":
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # Initialize the convolution layers: one branch produces features,
        # the other produces the soft gating mask
        if sn:
            self.conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
            self.mask_conv2d = SpectralNorm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride,
                    padding=0,
                    dilation=dilation,
                )
            )
        else:
            self.conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )
            self.mask_conv2d = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding=0,
                dilation=dilation,
            )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.pad(x)
        conv = self.conv2d(x)
        mask = self.mask_conv2d(x)
        gated_mask = self.sigmoid(mask)
        if self.activation:
            conv = self.activation(conv)
        x = conv * gated_mask
        return x


class TransposeGatedConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        pad_type="zero",
        activation="lrelu",
        norm="none",
        sn=True,
        scale_factor=2,
    ):
        super(TransposeGatedConv2d, self).__init__()
        # Upsample by nearest-neighbour interpolation, then apply a gated conv
        self.scale_factor = scale_factor
        self.gated_conv2d = GatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            pad_type,
            activation,
            norm,
            sn,
        )

    def forward(self, x):
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode="nearest",
            recompute_scale_factor=False,
        )
        x = self.gated_conv2d(x)
        return x
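
# Illustrative note (not part of the original module): the gated convolution
# computes  output = activation(conv2d(x)) * sigmoid(mask_conv2d(x)),
# so the sigmoid branch acts as a learned, per-pixel soft mask.
# A minimal sketch with hypothetical sizes:
#
#   gate = GatedConv2d(4, 48, kernel_size=5, stride=1, padding=2)
#   x = torch.randn(1, 4, 256, 256)   # e.g. masked RGB image + binary mask
#   y = gate(x)                       # expected shape: (1, 48, 256, 256)
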
# ----------------------------------------
#            Layer Norm
# ----------------------------------------
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-8, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = Parameter(torch.Tensor(num_features).uniform_())
            self.beta = Parameter(torch.zeros(num_features))

    def forward(self, x):
        # layer norm over all non-batch dimensions
        shape = [-1] + [1] * (x.dim() - 1)  # for 4d input: [-1, 1, 1, 1]
        if x.size(0) == 1:
            # These two lines run much faster in PyTorch 0.4 than the two
            # lines in the else branch below.
            mean = x.view(-1).mean().view(*shape)
            std = x.view(-1).std().view(*shape)
        else:
            mean = x.view(x.size(0), -1).mean(1).view(*shape)
            std = x.view(x.size(0), -1).std(1).view(*shape)
        x = (x - mean) / (std + self.eps)
        # if it is learnable
        if self.affine:
            shape = [1, -1] + [1] * (x.dim() - 2)  # for 4d input: [1, -1, 1, 1]
            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


# -----------------------------------------------
#                  SpectralNorm
# -----------------------------------------------
def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)


class SpectralNorm(nn.Module):
    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(
                torch.mv(torch.t(w.view(height, -1).data), u.data)
            )
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
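
# Illustrative note (not part of the original module): SpectralNorm keeps
# per-layer vectors u and v and refines them by power iteration, so that
# sigma = u^T W v approximates the largest singular value of the flattened
# weight; the weight actually used in forward is then W / sigma.
# A minimal wrapping sketch:
#
#   sn_conv = SpectralNorm(nn.Conv2d(64, 64, 3, padding=1))
#   y = sn_conv(torch.randn(1, 64, 32, 32))   # same shape as a plain conv
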
class ContextualAttention(nn.Module):
    def __init__(
        self,
        ksize=3,
        stride=1,
        rate=1,
        fuse_k=3,
        softmax_scale=10,
        fuse=True,
        use_cuda=True,
        device_ids=None,
    ):
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda
        self.device_ids = device_ids

    def forward(self, f, b, mask=None):
        """Contextual attention layer implementation.

        Contextual attention was first introduced in the publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

        Args:
            f: Input feature to match (foreground).
            b: Input feature to match against (background).
            mask: Input mask for b, indicating patches that are not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.

        Returns:
            torch.Tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(
            b,
            ksizes=[kernel, kernel],
            strides=[self.rate * self.stride, self.rate * self.stride],
            rates=[1, 1],
            padding="same",
        )  # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L] e.g. [4, 192, 4, 4, 1024]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscale both foreground and
        # background for matching, and use the original background for
        # reconstruction.
        f = F.interpolate(
            f,
            scale_factor=1.0 / self.rate,
            mode="nearest",
            recompute_scale_factor=False,
        )
        b = F.interpolate(
            b,
            scale_factor=1.0 / self.rate,
            mode="nearest",
            recompute_scale_factor=False,
        )
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(
            f, 1, dim=0
        )  # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = extract_image_patches(
            b,
            ksizes=[self.ksize, self.ksize],
            strides=[self.stride, self.stride],
            rates=[1, 1],
            padding="same",
        )
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        # process mask; a missing mask means every background patch is usable
        if mask is None:
            mask = torch.zeros([1, 1, raw_int_bs[2], raw_int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        mask = F.interpolate(
            mask,
            scale_factor=1.0 / self.rate,
            mode="nearest",
            recompute_scale_factor=False,
        )
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = extract_image_patches(
            mask,
            ksizes=[self.ksize, self.ksize],
            strides=[self.stride, self.stride],
            rates=[1, 1],
            padding="same",
        )
        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)  # m shape: [N, L, C, k, k]
        m = m[0]  # m shape: [L, C, k, k]
        # mm shape: [L, 1, 1, 1]; 1 marks background patches that are fully valid
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.0).to(
            torch.float32
        )
        mm = mm.permute(1, 0, 2, 3)  # mm shape: [1, L, 1, 1]

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale  # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()
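
        # Note (added for clarity, not in the original source): the loop below
        # treats each background patch in wi as a convolution filter, so
        # F.conv2d(xi, wi_normed) scores every foreground location against
        # every background patch (cosine similarity). The optional "fuse"
        # step spreads scores over neighbouring patches, softmax turns the
        # scores into attention weights, and conv_transpose2d pastes the
        # full-resolution background patches back according to those weights.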
        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            """
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front;
                 (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back;
                 (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back;
                 (B=1, I=32*32, O=128, KH=4, KW=4)
            """
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # [L, C, k, k]
            max_wi = torch.sqrt(
                reduce_sum(
                    torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True
                )
            )
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = same_padding(
                xi, [self.ksize, self.ksize], [1, 1], [1, 1]
            )  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]
            # conv implementation for fusing scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(
                    1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
                )  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(
                    yi, fuse_weight, stride=1
                )  # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(
                    1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]
                )  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(
                    1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3]
                )
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(
                    1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]
                )
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
            yi = yi.view(
                1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]
            )  # (B=1, C=32*32, H=32, W=32)

            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm  # [1, L, H, W]

            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W

            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(
                    int_bs[2] * int_bs[3]
                )
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat(
                [offset // int_fs[3], offset % int_fs[3]], dim=1
            )  # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1])  # here may need conv_transpose same padding
            yi = (
                F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1)
                / 4.0
            )  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)  # collected but not returned

        y = torch.cat(y, dim=0)  # back to the mini-batch
        y = y.contiguous().view(raw_int_fs)
        return y
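

# Illustrative usage sketch (not part of the original module). It assumes the
# helpers extract_image_patches, reduce_mean, reduce_sum and same_padding are
# provided by deepfillv2.network_utils, as imported above; the tensor sizes
# below are hypothetical.
#
#   attn = ContextualAttention(ksize=3, stride=1, rate=2, fuse=True,
#                              use_cuda=False)
#   fg = torch.randn(1, 128, 64, 64)       # decoder features (hole region)
#   bg = torch.randn(1, 128, 64, 64)       # encoder features (known region)
#   hole_mask = torch.zeros(1, 1, 64, 64)  # 1 marks unavailable background
#   out = attn(fg, bg, hole_mask)          # expected shape: (1, 128, 64, 64)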