import math

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms, utils

from .normalizer import PatchNormalizer, PatchedHarmonizer
from .util import rgb_to_lab, lab_to_rgb, lab_shift
# Expected network input resolution: 256 x 512 x 3 (H, W, C).
def inpaint_bg(comp, mask, dim=[2, 3]):
    """
    Replace the masked (foreground) region with the mean color of the background,
    a simple background fill used to prepare image-harmonization inputs.
    Args:
        comp (torch.float): composite image in [0, 1], shape (B, C, H, W)
        mask (torch.float): foreground mask in [0, 1], shape (B, 1, H, W)
    """
    back = comp * (1 - mask)               # keep background pixels only
    total = torch.sum(back, dim=dim)       # (B, C) sum of background values
    num = torch.sum((1 - mask), dim=dim)   # (B, 1) number of background pixels
    mu = total / num                       # per-channel background mean
    mean = mu[:, :, None, None]
    back = back + mask * mean              # fill the masked area with that mean
    return back
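# Example (shapes only, illustrative values): for comp of shape (1, 3, 256, 512) and a binary
# mask of shape (1, 1, 256, 512), inpaint_bg returns a (1, 3, 256, 512) tensor whose masked
# region is filled with the per-channel mean of the background pixels.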
class ConvTransposeUp(nn.Sequential):
    # NOTE: shadowed by the second `ConvTransposeUp` defined further below (the variant with a
    # `norm` argument, which is the one PHNet's decoder actually uses).
    def __init__(self, in_channels, out_channels, kernel_size=4, padding=1, stride=2, activation=None):
        super().__init__(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=kernel_size, padding=padding, stride=stride),
            activation() if activation is not None else nn.Identity(),
        )
class UpsampleShuffle(nn.Sequential):
def __init__(self, in_channels, out_channels, activation=True):
super().__init__(
nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
nn.GELU() if activation else nn.Identity(),
nn.PixelShuffle(2)
)
def reset_parameters(self):
init_subpixel(self[0].weight)
nn.init.zeros_(self[0].bias)
class UpsampleResize(nn.Sequential):
def __init__(self, in_channels, out_channels, out_size=None, activation=None, scale_factor=2., mode='bilinear'):
super().__init__(
nn.Upsample(scale_factor=scale_factor, mode=mode) if out_size is None else nn.Upsample(
out_size, mode=mode),
nn.ReflectionPad2d(1),
nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=1, padding=0),
activation() if activation is not None else nn.Identity(),
)
def conv_bn(in_, out_, kernel_size=3, stride=1, padding=1, activation=nn.ReLU, normalization=nn.InstanceNorm2d):
return nn.Sequential(
nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=padding),
normalization(out_) if normalization is not None else nn.Identity(),
activation(),
)
def init_subpixel(weight):
    # ICNR-style initialization for PixelShuffle convolutions: initialize one sub-kernel
    # and repeat it across the output channels so the shuffle starts as a smooth upsampling.
    co, ci, h, w = weight.shape
    co2 = co // 4
    # initialize the sub kernel
    k = torch.empty([co2, ci, h, w])
    nn.init.kaiming_uniform_(k)
    # repeat 4 times along the output-channel dimension
    k = k.repeat_interleave(4, dim=0)
    weight.data.copy_(k)
class DownsampleShuffle(nn.Sequential):
def __init__(self, in_channels):
assert in_channels % 4 == 0
super().__init__(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
nn.ReLU(),
nn.PixelUnshuffle(2)
)
def reset_parameters(self):
init_subpixel(self[0].weight)
nn.init.zeros_(self[0].bias)
def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
    # conv layer with ELU activation; `padding=True` keeps the spatial size for odd kernels
    pad = kernel_size // 2 if padding is not False else 0
    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
        nn.ELU(),
    )
class Inference_Data(Dataset):
    """Single-image dataset: loads one image, resizes it to 512x256 (W x H) and yields it as a tensor."""
    def __init__(self, img_path):
        # cv2.imread returns a BGR uint8 array
        self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        self.input_img = cv2.resize(
            self.input_img, (512, 256), interpolation=cv2.INTER_CUBIC)
        self.to_tensor = transforms.ToTensor()
        self.data_len = 1
    def __getitem__(self, index):
        self.tensor_img = self.to_tensor(self.input_img)
        return self.tensor_img
    def __len__(self):
        return self.data_len
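# A minimal usage sketch (an assumption, not part of this file's original pipeline):
#   loader = torch.utils.data.DataLoader(Inference_Data("composite.jpg"), batch_size=1)
#   for tensor_img in loader: ...
# Note that cv2 keeps BGR channel order, so a BGR->RGB conversion may be needed before the model.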
class MyAdaptiveMaxPool2d(nn.Module):
    """Global max pooling over the spatial dimensions; the `sz` argument is accepted but ignored."""
    def __init__(self, sz=None):
        super().__init__()
    def forward(self, x):
        inp_size = x.size()
        return nn.functional.max_pool2d(input=x,
                                        kernel_size=(inp_size[2], inp_size[3]))
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block with an optional auxiliary spatial gate."""
    def __init__(self, channel, reduction=8):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid())
    def forward(self, x, aux_inp=None):
        b, c, h, w = x.size()
        # squeeze: global average pool; excite: two-layer MLP producing per-channel weights
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        r = x * y
        if aux_inp is not None:
            # additional spatial gate derived from the auxiliary (patch-harmonized) input
            aux_weights = MyAdaptiveMaxPool2d(aux_inp.shape[-1] // 8)(aux_inp)
            aux_weights = nn.Sigmoid()(aux_weights.mean(1, keepdim=True))
            r = r + x * aux_weights
        return r
class ConvTransposeUp(nn.Sequential):
    """Transposed-convolution upsampling block with optional normalization and activation."""
    def __init__(self, in_channels, out_channels, norm, kernel_size=3, stride=2, padding=1, activation=None):
        super().__init__(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=kernel_size, padding=padding, stride=stride),
            norm(out_channels) if norm is not None else nn.Identity(),
            activation() if activation is not None else nn.Identity(),
        )
class SkipConnect(nn.Module):
    """Fuses a decoder feature with its encoder skip feature via a 3x3 convolution and ReLU."""
    def __init__(self, channel):
        super(SkipConnect, self).__init__()
        self.rconv = nn.Conv2d(channel * 2, channel, 3, padding=1, bias=False)
    def forward(self, feature):
        return F.relu(self.rconv(feature))
class AttentionBlock(nn.Module):
def __init__(self, in_channels):
super(AttentionBlock, self).__init__()
self.attn = nn.Sequential(
nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
return self.attn(x)
class PatchHarmonizerBlock(nn.Module):
def __init__(self, in_channels=3, grid_count=5):
super(PatchHarmonizerBlock, self).__init__()
self.patch_harmonizer = PatchedHarmonizer(grid_count=grid_count)
self.head = conv_bn(in_channels*2, in_channels,
kernel_size=3, padding=1, normalization=None)
def forward(self, fg, bg, mask):
fg_harm, _ = self.patch_harmonizer(fg, bg, mask)
return self.head(torch.cat([fg, fg_harm], 1))
class PHNet(nn.Module):
def __init__(self, enc_sizes=[3, 16, 32, 64, 128, 256, 512], skips=True, grid_count=[10, 5, 1], init_weights=[0.5, 0.5], init_value=0.8):
super(PHNet, self).__init__()
self.skips = skips
self.feature_extractor = PatchHarmonizerBlock(
in_channels=enc_sizes[0], grid_count=grid_count[1])
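        # Encoder: six conv_bn blocks alternating stride-2 downsampling (kernel 4) and
        # stride-1 refinement (kernel 3); channels grow 3 -> 16 -> 32 -> 64 -> 128 -> 256 -> 512.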
self.encoder = nn.ModuleList([
conv_bn(enc_sizes[0], enc_sizes[1],
kernel_size=4, stride=2),
conv_bn(enc_sizes[1], enc_sizes[2],
kernel_size=3, stride=1),
conv_bn(enc_sizes[2], enc_sizes[3],
kernel_size=4, stride=2),
conv_bn(enc_sizes[3], enc_sizes[4],
kernel_size=3, stride=1),
conv_bn(enc_sizes[4], enc_sizes[5],
kernel_size=4, stride=2),
conv_bn(enc_sizes[5], enc_sizes[6],
kernel_size=3, stride=1),
])
dec_ins = enc_sizes[::-1]
dec_sizes = enc_sizes[::-1]
self.start_level = len(dec_sizes) - len(grid_count)
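        # Patch-based normalizers are applied to the encoder skip features of the last
        # len(grid_count) decoder stages (the shallow, high-resolution ones).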
self.normalizers = nn.ModuleList([
PatchNormalizer(in_channels=dec_sizes[self.start_level+i], grid_count=count, weights=init_weights, eps=1e-7, init_value=init_value) for i, count in enumerate(grid_count)
])
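        # Decoder mirrors the encoder, alternating stride-1 and stride-2 transposed convolutions;
        # the first three stages use BatchNorm and the final stage outputs 3 channels with no activation.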
self.decoder = nn.ModuleList([
ConvTransposeUp(
dec_ins[0], dec_sizes[1], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
ConvTransposeUp(
dec_ins[1], dec_sizes[2], norm=nn.BatchNorm2d, kernel_size=4, stride=2, activation=nn.LeakyReLU),
ConvTransposeUp(
dec_ins[2], dec_sizes[3], norm=nn.BatchNorm2d, kernel_size=3, stride=1, activation=nn.LeakyReLU),
ConvTransposeUp(
dec_ins[3], dec_sizes[4], norm=None, kernel_size=4, stride=2, activation=nn.LeakyReLU),
ConvTransposeUp(
dec_ins[4], dec_sizes[5], norm=None, kernel_size=3, stride=1, activation=nn.LeakyReLU),
ConvTransposeUp(
dec_ins[5], 3, norm=None, kernel_size=4, stride=2, activation=None),
])
self.skip = nn.ModuleList([
SkipConnect(x) for x in dec_ins
])
self.SE_block = SEBlock(enc_sizes[6])
    def forward(self, img, mask):
        x = img
        enc_outs = [x]
        # patch-harmonized composite, used as an auxiliary gate in the SE block at the bottleneck
        x_harm = self.feature_extractor(x * mask, x * (1 - mask), mask)
        masks = [mask]
        for i, down_layer in enumerate(self.encoder):
            x = down_layer(x)
            # even-indexed encoder blocks have stride 2, so the mask is halved only on those steps
            scale_factor = 1. / (pow(2, 1 - i % 2))
            masks.append(F.interpolate(masks[-1], scale_factor=scale_factor))
            enc_outs.append(x)
        x = self.SE_block(x, aux_inp=x_harm)
        masks = masks[::-1]
        for i, (up_layer, enc_out) in enumerate(zip(self.decoder, enc_outs[::-1])):
            if i >= self.start_level:
                # normalize the shallow, high-resolution skip features patch-wise
                enc_out = self.normalizers[i - self.start_level](enc_out, enc_out, masks[i])
            x = torch.cat([x, enc_out], 1)
            x = self.skip[i](x)
            x = up_layer(x)
        harmonized = torch.sigmoid(x)
        return harmonized
    def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"], value=False):
        # toggle gradients for the named sub-modules; names that are not attributes are skipped
        for module in modules:
            attr = getattr(self, module, None)
            if attr is not None:
                attr.requires_grad_(value)
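
# A minimal smoke-test sketch (assumptions: the module is imported as part of its package so the
# relative imports resolve, and a randomly initialized PHNet is acceptable). It only illustrates
# the tensor shapes expected by PHNet.forward; the mask region below is hypothetical.
if __name__ == "__main__":
    model = PHNet()                        # default enc_sizes expect a 3-channel composite
    model.eval()
    img = torch.rand(1, 3, 256, 512)       # composite image in [0, 1], (B, C, H, W)
    mask = torch.zeros(1, 1, 256, 512)     # foreground mask in [0, 1]
    mask[:, :, 64:192, 128:384] = 1.0      # hypothetical rectangular foreground region
    with torch.no_grad():
        out = model(img, mask)
    print(out.shape)                       # expected: torch.Size([1, 3, 256, 512])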