File size: 4,445 Bytes
712b45c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import os
import numpy as np
import cv2
import torch
from deepfillv2 import network
import skimage
from config import GPU_DEVICE
# ----------------------------------------
# Network
# ----------------------------------------
def create_generator(opt):
    """Build a ``network.GatedGenerator`` from ``opt`` and initialise its weights.

    Args:
        opt: Options object. It is passed to ``network.GatedGenerator``, and its
            ``init_type`` / ``init_gain`` attributes are forwarded to
            ``network.weights_init``.

    Returns:
        The generator instance. The return value of ``weights_init`` is
        discarded, so that call presumably initialises the module in place.
    """
    gen_net = network.GatedGenerator(opt)
    print("-- Generator is created! --")
    network.weights_init(gen_net, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialized generator with %s type --" % opt.init_type)
    return gen_net
def create_discriminator(opt):
    """Build a ``network.PatchDiscriminator`` from ``opt`` and initialise its weights.

    Args:
        opt: Options object. It is passed to ``network.PatchDiscriminator``, and
            its ``init_type`` / ``init_gain`` attributes are forwarded to
            ``network.weights_init``.

    Returns:
        The discriminator instance. The return value of ``weights_init`` is
        discarded, so that call presumably initialises the module in place.
    """
    disc_net = network.PatchDiscriminator(opt)
    print("-- Discriminator is created! --")
    network.weights_init(disc_net, init_type=opt.init_type, init_gain=opt.init_gain)
    print("-- Initialize discriminator with %s type --" % opt.init_type)
    return disc_net
def create_perceptualnet():
    """Instantiate ``network.PerceptualNet`` (used as the perceptual-loss network).

    NOTE(review): an earlier comment said this is the first 15 layers of VGG16
    (up to conv3_3). That is defined in ``network`` and is not verifiable here.

    Returns:
        The ``network.PerceptualNet()`` instance.
    """
    vgg_features = network.PerceptualNet()
    print("-- Perceptual network is created! --")
    return vgg_features
# ----------------------------------------
# PATH processing
# ----------------------------------------
def text_readlines(filename):
    """Read a text file and return its lines with the trailing newline removed.

    Args:
        filename: Path of the file to read.

    Returns:
        A list with one string per line. At most one trailing newline character
        is removed from each line. Returns ``[]`` if ``open`` raises
        ``IOError`` (an alias of ``OSError``), e.g. when the file is missing.
    """
    try:
        file = open(filename, "r")
    except IOError:
        return []
    # The context manager guarantees the file is closed even if reading fails.
    with file:
        # Strip only an actual newline, so a final line that has no trailing
        # newline keeps its last character.
        return [line[:-1] if line.endswith("\n") else line for line in file]
def savetxt(name, loss_log):
    """Write ``loss_log`` to the file ``name`` using ``numpy.savetxt``.

    Args:
        name: Destination path (or file handle) accepted by ``np.savetxt``.
        loss_log: Sequence of values; it is converted with ``np.array`` first.
    """
    np.savetxt(name, np.array(loss_log))
def get_files(path, mask=False):
    """Recursively list full paths of the files under ``path``, skipping ``.DS_Store``.

    Args:
        path: Root directory to walk with ``os.walk``. A non-existent path
            yields an empty list.
        mask: Accepted for backward compatibility; it is not used.

    Returns:
        A list of ``os.path.join(root, filename)`` strings in ``os.walk`` order.
    """
    ret = []
    for root, _dirs, files in os.walk(path):
        for filename in files:
            # Skip macOS Finder metadata files. The previous condition was
            # inverted and returned only these files.
            if filename == ".DS_Store":
                continue
            ret.append(os.path.join(root, filename))
    return ret
def get_names(path):
    """Recursively list the base names of all files under ``path``.

    Args:
        path: Root directory to walk with ``os.walk``.

    Returns:
        A list of bare file names (no directory part), in ``os.walk`` order.
        Nothing is filtered out.
    """
    return [name for _root, _subdirs, names in os.walk(path) for name in names]
def text_save(content, filename, mode="a"):
    """Write each item of ``content`` to ``filename`` as one line of text.

    Args:
        content: Iterable of items. Each item is written as ``str(item)``
            followed by a newline.
        filename: Destination path.
        mode: File mode passed to ``open``. The default ``"a"`` appends to an
            existing file; use ``"w"`` to overwrite.
    """
    # The context manager closes the file even if a write raises.
    with open(filename, mode) as file:
        for item in content:
            file.write(str(item) + "\n")
def check_path(path):
    """Create directory ``path`` (including parents) unless something already exists there.

    If ``path`` exists, whether as a directory or a file, this does nothing.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
# ----------------------------------------
# Validation and Sample at training
# ----------------------------------------
def save_sample_png(
    sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255
):
    """Save one image per tensor in ``img_list`` into ``sample_folder``.

    Each image is written to ``<sample_folder>/<sample_name>_<name_list[i]>.jpg``.
    Previously every image was written to the same path, so each overwrote
    the last, and ``name_list`` was unused.

    Args:
        sample_folder: Directory to write into. It must already exist.
        sample_name: Filename prefix shared by all images in this call.
        img_list: Sequence of 4-D tensors. Each is permuted with
            ``(0, 2, 3, 1)`` and only its first batch element is saved.
            Values are multiplied by 255; an earlier note says the source
            layer is sigmoid-activated, so inputs are presumably in [0, 1].
        name_list: Per-image name suffixes, indexed in parallel with
            ``img_list``. It must be at least as long as ``img_list``.
        pixel_max_cnt: Upper bound used when clipping before the uint8 cast.
    """
    for i, img in enumerate(img_list):
        # Scaling creates a new tensor, so the caller's data is never modified.
        img = img * 255
        # Channels-last, first batch element, moved to the CPU for NumPy.
        img_copy = img.detach().permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
        img_copy = np.clip(img_copy, 0, pixel_max_cnt)
        img_copy = img_copy.astype(np.uint8)
        # cv2.imwrite expects BGR ordering; the tensor is treated as RGB.
        img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
        save_img_name = f"{sample_name}_{name_list[i]}.jpg"
        save_img_path = os.path.join(sample_folder, save_img_name)
        cv2.imwrite(save_img_path, img_copy)
def psnr(pred, target, pixel_max_cnt=255):
    """Peak signal-to-noise ratio between two tensors, in dB.

    Computes ``20 * log10(pixel_max_cnt / rmse)``, where ``rmse`` is the root
    of the mean squared difference over all elements.

    Args:
        pred: Predicted tensor.
        target: Reference tensor, broadcast-compatible with ``pred``.
        pixel_max_cnt: Maximum possible pixel value.

    Returns:
        The PSNR as a float. Returns ``float("inf")`` when the inputs are
        identical (rmse == 0) instead of raising ``ZeroDivisionError``.
    """
    diff = target - pred
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        # Zero error means unbounded PSNR; avoid dividing by zero.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt / rmse_avg)
def grey_psnr(pred, target, pixel_max_cnt=255):
    """PSNR, in dB, after summing both tensors over dimension 0.

    The peak value is taken as ``pixel_max_cnt * 3``. NOTE(review): this
    presumably assumes dim 0 holds three channels (e.g. RGB of a single
    image), so the sum is a grey-level proxy. Confirm the input shape with
    the callers.

    Args:
        pred: Predicted tensor.
        target: Reference tensor with the same shape as ``pred``.
        pixel_max_cnt: Maximum possible value of a single element.

    Returns:
        The PSNR as a float. Returns ``float("inf")`` when the summed tensors
        are identical (rmse == 0) instead of raising ``ZeroDivisionError``.
    """
    pred = torch.sum(pred, dim=0)
    target = torch.sum(target, dim=0)
    diff = target - pred
    rmse_avg = torch.mean(diff * diff).item() ** 0.5
    if rmse_avg == 0:
        # Zero error means unbounded PSNR; avoid dividing by zero.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
def ssim(pred, target, data_range=None):
    """Structural similarity between the first batch elements of two 4-D tensors.

    Both tensors are permuted with ``(0, 2, 3, 1)`` (channels last), and only
    index 0 of the leading dimension is compared, as a multichannel image.

    Args:
        pred: Predicted 4-D tensor.
        target: Reference 4-D tensor with the same shape.
        data_range: Optional value range forwarded to scikit-image. Newer
            scikit-image versions may require it for floating-point inputs;
            ``None`` keeps the library's default handling.

    Returns:
        The SSIM value computed by scikit-image.
    """
    import inspect

    # skimage.metrics.structural_similarity exists since scikit-image 0.16,
    # and the old skimage.measure.compare_ssim was removed in 0.18.
    try:
        from skimage.metrics import structural_similarity as ssim_fn
    except ImportError:
        from skimage.measure import compare_ssim as ssim_fn

    # .numpy() only works on CPU tensors, so move the data to the CPU
    # (moving it to a CUDA device would make .numpy() raise).
    pred_img = pred.detach().permute(0, 2, 3, 1)[0].cpu().numpy()
    target_img = target.detach().permute(0, 2, 3, 1)[0].cpu().numpy()

    # scikit-image 0.19 replaced multichannel=True with channel_axis.
    if "channel_axis" in inspect.signature(ssim_fn).parameters:
        channel_kwargs = {"channel_axis": -1}
    else:
        channel_kwargs = {"multichannel": True}
    return ssim_fn(target_img, pred_img, data_range=data_range, **channel_kwargs)
|