# HealthiVert-GAN / models / inpaint_tools.py
# Uploaded by ZhangqiSJTU ("Upload 96 files", commit 7d21475, verified).
# This code was taken and only slightly adapted from https://github.com/DAA233/generative-inpainting-pytorch in September 2019.
import os
import torch
import numpy as np
def same_padding(images, ksizes, strides, rates):
    """
    Zero-pad a 4-D batch so a sliding window yields ceil(size / stride) positions.

    :param images: 4-D tensor; the last two dimensions are read as rows, cols
    :param ksizes: [ksize_rows, ksize_cols]
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: the input padded with zeros on its last two dimensions; when the
        total padding along an axis is odd, the extra element goes to the
        bottom / right side
    """
    assert len(images.size()) == 4
    _, _, in_rows, in_cols = images.size()

    def _total_pad(extent, ksize, stride, rate):
        # Number of window positions wanted along this axis.
        n_out = (extent + stride - 1) // stride
        # Span covered by one dilated kernel.
        k_eff = (ksize - 1) * rate + 1
        return max(0, (n_out - 1) * stride + k_eff - extent)

    pad_rows = _total_pad(in_rows, ksizes[0], strides[0], rates[0])
    pad_cols = _total_pad(in_cols, ksizes[1], strides[1], rates[1])

    top = pad_rows // 2
    left = pad_cols // 2
    bottom = pad_rows - top
    right = pad_cols - left

    # ZeroPad2d expects (left, right, top, bottom).
    padder = torch.nn.ZeroPad2d((left, right, top, bottom))
    return padder(images)
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.

    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
        each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' pads the input with ``same_padding`` before
        unfolding; 'valid' unfolds the input as given
    :return: A Tensor of shape [N, C*k*k, L], where L is the total number of
        extracted blocks
    :raises NotImplementedError: if ``padding`` is neither 'same' nor 'valid'
    """
    assert len(images.size()) == 4

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding != 'valid':
        # Raised explicitly (instead of via assert) so the check survives
        # ``python -O`` and callers see the intended exception type.
        raise NotImplementedError(
            'Unsupported padding type: {}. '
            'Only "same" or "valid" are supported.'.format(padding))

    # Padding was already applied above, so Unfold itself pads nothing.
    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    return unfold(images)  # [N, C*k*k, L], L is the total number of such blocks
def reduce_mean(x, axis=None, keepdim=False):
    """
    Average ``x`` over one or more dimensions.

    :param x: tensor to reduce
    :param axis: ``None`` or an empty sequence reduces over every dimension;
        an int reduces over that single dimension (so ``axis=0`` means
        dimension 0, not "all"); a sequence of ints reduces over each listed
        dimension. Negative indices count from the last dimension.
    :param keepdim: passed through to ``torch.mean``
    :return: the reduced tensor
    """
    ndim = len(x.shape)
    if axis is None:
        dims = range(ndim)
    elif isinstance(axis, int):
        dims = [axis]
    else:
        # An empty sequence keeps its historical meaning of "all dimensions".
        dims = list(axis) or range(ndim)
    # Resolve negative indices against the original rank up front: once a
    # dimension has been removed (keepdim=False) they would point elsewhere.
    dims = {d + ndim if d < 0 else d for d in dims}
    # Reduce from the highest dimension down so the remaining indices stay valid.
    for d in sorted(dims, reverse=True):
        x = torch.mean(x, dim=d, keepdim=keepdim)
    return x
def reduce_sum(x, axis=None, keepdim=False):
    """
    Sum ``x`` over one or more dimensions.

    :param x: tensor to reduce
    :param axis: ``None`` or an empty sequence reduces over every dimension;
        an int reduces over that single dimension (so ``axis=0`` means
        dimension 0, not "all"); a sequence of ints reduces over each listed
        dimension. Negative indices count from the last dimension.
    :param keepdim: passed through to ``torch.sum``
    :return: the reduced tensor
    """
    ndim = len(x.shape)
    if axis is None:
        dims = range(ndim)
    elif isinstance(axis, int):
        dims = [axis]
    else:
        # An empty sequence keeps its historical meaning of "all dimensions".
        dims = list(axis) or range(ndim)
    # Resolve negative indices against the original rank up front: once a
    # dimension has been removed (keepdim=False) they would point elsewhere.
    dims = {d + ndim if d < 0 else d for d in dims}
    # Reduce from the highest dimension down so the remaining indices stay valid.
    for d in sorted(dims, reverse=True):
        x = torch.sum(x, dim=d, keepdim=keepdim)
    return x
def flow_to_image(flow):
    """
    Transfer flow map to image.
    Part of code forked from flownet.

    :param flow: array indexed as [batch, row, col, 2]; the last axis holds
        the u (index 0) and v (index 1) components
    :return: float32 array stacking one ``compute_color`` image per batch item

    Components whose magnitude exceeds 1e7 are treated as unknown and replaced
    by 0. The caller's array is left untouched. Each item is normalised by the
    largest radius seen so far in the batch (a running maximum), so the result
    for item ``i`` depends on items ``0..i``.
    """
    out = []
    maxrad = -1
    eps = np.finfo(float).eps  # keeps the division finite when maxrad is 0
    for i in range(flow.shape[0]):
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        unknown = (abs(u) > 1e7) | (abs(v) > 1e7)
        # np.where builds fresh arrays, so the views into ``flow`` are not
        # written to.
        u = np.where(unknown, 0, u)
        v = np.where(unknown, 0, v)
        rad = np.sqrt(u ** 2 + v ** 2)
        maxrad = max(maxrad, np.max(rad))
        u = u / (maxrad + eps)
        v = v / (maxrad + eps)
        out.append(compute_color(u, v))
    return np.float32(np.uint8(out))
def pt_flow_to_image(flow):
    """
    Transfer flow map to image.
    Part of code forked from flownet.

    :param flow: tensor indexed as [batch, 2, row, col]; channel 0 is u and
        channel 1 is v
    :return: tensor stacking one ``pt_compute_color`` image per batch item
        along dimension 0

    Components whose magnitude exceeds 1e7 are treated as unknown and replaced
    by 0. The caller's tensor is left untouched. Each item is normalised by
    the largest radius seen so far in the batch (a running maximum).
    """
    out = []
    eps = torch.finfo(torch.float32).eps  # keeps the division finite when maxrad is 0
    # Keep the running maximum on the same device as the input instead of
    # forcing it onto CUDA whenever a GPU happens to be present.
    maxrad = torch.tensor(-1, device=flow.device)
    for i in range(flow.shape[0]):
        u = flow[i, 0, :, :]
        v = flow[i, 1, :, :]
        unknown = (torch.abs(u) > 1e7) | (torch.abs(v) > 1e7)
        # torch.where returns new tensors, so the views into ``flow`` are not
        # written to.
        u = torch.where(unknown, torch.zeros_like(u), u)
        v = torch.where(unknown, torch.zeros_like(v), v)
        # NOTE(review): the radius is truncated to int64 before taking the
        # maximum, unlike the NumPy flow_to_image which keeps it as a float.
        # Kept as-is to preserve existing output -- confirm whether intended.
        rad = torch.sqrt((u ** 2 + v ** 2).float()).to(torch.int64)
        maxrad = torch.max(maxrad, torch.max(rad))
        u = u / (maxrad + eps)
        v = v / (maxrad + eps)
        out.append(pt_compute_color(u, v))
    return torch.stack(out, dim=0)
def highlight_flow(flow):
    """Convert flow into middlebury color code image.

    :param flow: integer array indexed as [batch, row, col, 2]; for every
        position, component 0 is used as a row index and component 1 as a
        column index into the output image
    :return: float32 array of shape [batch, rows, cols, 3]; pixels addressed
        by at least one flow entry are 255, all others 144
    :raises IndexError: if a flow entry is not a valid index into the image
    """
    flow = np.asarray(flow)
    out = []
    rows, cols = flow.shape[1], flow.shape[2]
    for i in range(flow.shape[0]):
        img = np.ones((rows, cols, 3)) * 144.
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        # One fancy-indexed assignment visits every (row, col) entry of the
        # flow, including the columns beyond ``rows`` on non-square inputs.
        img[u.ravel(), v.ravel(), :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))
def pt_highlight_flow(flow):
    """
    Convert flow into middlebury color code image.

    :param flow: integer tensor or array indexed as [batch, row, col, 2]; for
        every position, component 0 is used as a row index and component 1 as
        a column index into the output image
    :return: NumPy float32 array of shape [batch, rows, cols, 3]; pixels
        addressed by at least one flow entry are 255, all others 144
    :raises IndexError: if a flow entry is not a valid index into the image
    """
    # The image is assembled with NumPy, so bring tensors to host memory once
    # instead of reading them element by element.
    if torch.is_tensor(flow):
        flow = flow.detach().cpu().numpy()
    flow = np.asarray(flow)
    out = []
    rows, cols = flow.shape[1], flow.shape[2]
    for i in range(flow.shape[0]):
        img = np.ones((rows, cols, 3)) * 144.
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        # One fancy-indexed assignment visits every (row, col) entry of the
        # flow, including the columns beyond ``rows`` on non-square inputs.
        img[u.ravel(), v.ravel(), :] = 255.
        out.append(img)
    return np.float32(np.uint8(out))
def compute_color(u, v):
    """
    Colour-code one flow field using the wheel from ``make_color_wheel``.

    :param u: 2-D array, first flow component (NaN entries are zeroed in place)
    :param v: 2-D array, second flow component (NaN entries are zeroed in place)
    :return: array of shape [rows, cols, 3] holding integer-valued intensities
        in 0..255; positions where either input was NaN are 0
    """
    rows, cols = u.shape
    rgb = np.zeros([rows, cols, 3])

    invalid = np.isnan(u) | np.isnan(v)
    u[invalid] = 0
    v[invalid] = 0

    wheel = make_color_wheel()
    n_hues = np.size(wheel, 0)

    magnitude = np.sqrt(u ** 2 + v ** 2)
    # Map the flow direction to a fractional, 1-based position on the wheel.
    angle = np.arctan2(-v, -u) / np.pi
    position = (angle + 1) / 2 * (n_hues - 1) + 1
    lower = np.floor(position).astype(int)
    upper = lower + 1
    upper[upper == n_hues + 1] = 1  # wrap past the last wheel entry
    frac = position - lower

    inside = magnitude <= 1
    outside = np.logical_not(inside)
    keep = 1 - invalid  # 0 where the input was NaN, 1 elsewhere

    for ch in range(np.size(wheel, 1)):
        channel = wheel[:, ch]
        c_lower = channel[lower - 1] / 255
        c_upper = channel[upper - 1] / 255
        mixed = (1 - frac) * c_lower + frac * c_upper
        # Within the unit radius, blend towards white as the magnitude shrinks.
        mixed[inside] = 1 - magnitude[inside] * (1 - mixed[inside])
        # Beyond the unit radius, dim the colour instead.
        mixed[outside] *= 0.75
        rgb[:, :, ch] = np.uint8(np.floor(255 * mixed * keep))
    return rgb
def pt_compute_color(u, v):
    """
    Colour-code one flow field using the wheel from ``pt_make_color_wheel``.

    :param u: 2-D tensor, first flow component
    :param v: 2-D tensor, second flow component
    :return: float32 tensor of shape [3, rows, cols] on the same device as
        ``u``; positions where either input was NaN are 0

    The inputs are not modified.
    """
    h, w = u.shape
    device = u.device  # follow the input rather than "CUDA if available"
    img = torch.zeros([3, h, w], device=device)

    nan_mask = (torch.isnan(u) + torch.isnan(v)) != 0
    # torch.where returns new tensors, so the caller's u / v stay untouched.
    u = torch.where(nan_mask, torch.zeros_like(u), u)
    v = torch.where(nan_mask, torch.zeros_like(v), v)

    colorwheel = pt_make_color_wheel().to(device)
    ncols = colorwheel.size()[0]

    rad = torch.sqrt((u ** 2 + v ** 2).to(torch.float32))
    # Map the flow direction to a fractional, 1-based position on the wheel.
    a = torch.atan2(-v.to(torch.float32), -u.to(torch.float32)) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = torch.floor(fk).to(torch.int64)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1  # wrap past the last wheel entry
    f = fk - k0.to(torch.float32)

    # NOTE(review): the NumPy compute_color uses ``rad <= 1`` here; the
    # 1/255 threshold is kept unchanged -- confirm which one is intended.
    idx = rad <= 1. / 255.
    # ``mask == 0`` inverts the mask. The previous ``idx != 0`` reproduced
    # ``idx`` itself, so the 0.75 dimming hit the wrong pixels.
    notidx = (idx == 0)
    # ``1 - bool_tensor`` is rejected by current PyTorch; invert, then cast.
    valid = (nan_mask == 0).to(torch.float32)

    for i in range(colorwheel.size()[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1]
        col1 = tmp[k1 - 1]
        col = (1 - f) * col0 + f * col1
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        col[notidx] *= 0.75
        img[i, :, :] = col * valid
    return img
def make_color_wheel():
    """
    Build the RGB colour wheel used to colour-code flow directions.

    The wheel has 55 entries made of six consecutive segments (RY, YG, GC,
    CB, BM, MR). In each segment one channel is held at 255 while another
    ramps up from 0 or down from 255 in floored steps.

    :return: float array of shape [55, 3] with values in 0..255
    """
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros([ncols, 3])
    col = 0
    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
    col += RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
    colorwheel[col:col + YG, 1] = 255
    col += YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
    col += GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB)
    colorwheel[col:col + CB, 2] = 255
    col += CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
    col += BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR)
    colorwheel[col:col + MR, 0] = 255
    return colorwheel
def pt_make_color_wheel():
    """
    Build the RGB colour wheel as a torch tensor with values in [0, 1].

    The wheel consists of six consecutive segments (RY, YG, GC, CB, BM, MR).
    In each segment one channel is held at 1 while another ramps linearly,
    either rising from 0 or falling from 1.

    :return: float32 tensor of shape [55, 3]
    """
    # (segment length, channel held at 1, ramping channel, ramp rises?)
    segments = (
        (15, 0, 1, True),    # RY
        (6, 1, 0, False),    # YG
        (4, 1, 2, True),     # GC
        (11, 2, 1, False),   # CB
        (13, 2, 0, True),    # BM
        (6, 0, 2, False),    # MR
    )
    total = sum(length for length, _, _, _ in segments)
    wheel = torch.zeros([total, 3])

    start = 0
    for length, held, ramped, rising in segments:
        stop = start + length
        ramp = torch.arange(0, length, dtype=torch.float32) / length
        wheel[start:stop, held] = 1.
        wheel[start:stop, ramped] = ramp if rising else 1. - ramp
        start = stop
    return wheel