File size: 4,445 Bytes
712b45c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import numpy as np
import cv2
import torch
from deepfillv2 import network
import skimage

from config import GPU_DEVICE


# ----------------------------------------
#                 Network
# ----------------------------------------
def create_generator(opt):
    """Build the gated generator and run weight initialisation on it.

    Args:
        opt: options object. It is passed unchanged to
            ``network.GatedGenerator``; its ``init_type`` and ``init_gain``
            attributes are forwarded to ``network.weights_init``.

    Returns:
        The object returned by ``network.GatedGenerator(opt)`` after
        ``network.weights_init`` has been called on it.
    """
    net = network.GatedGenerator(opt)
    print("-- Generator is created! --")

    # Collect the init settings first so the call below stays on one line.
    init_kwargs = {"init_type": opt.init_type, "init_gain": opt.init_gain}
    network.weights_init(net, **init_kwargs)
    print("-- Initialized generator with %s type --" % opt.init_type)

    return net


def create_discriminator(opt):
    """Build the patch discriminator and run weight initialisation on it.

    Args:
        opt: options object. It is passed unchanged to
            ``network.PatchDiscriminator``; its ``init_type`` and
            ``init_gain`` attributes are forwarded to
            ``network.weights_init``.

    Returns:
        The object returned by ``network.PatchDiscriminator(opt)`` after
        ``network.weights_init`` has been called on it.
    """
    net = network.PatchDiscriminator(opt)
    print("-- Discriminator is created! --")

    # Collect the init settings first so the call below stays on one line.
    init_kwargs = {"init_type": opt.init_type, "init_gain": opt.init_gain}
    network.weights_init(net, **init_kwargs)
    print("-- Initialize discriminator with %s type --" % opt.init_type)

    return net


def create_perceptualnet():
    """Instantiate and return ``network.PerceptualNet``.

    NOTE(review): the original comment says this is the first 15 layers of
    VGG16 (up to conv3_3); that is decided inside ``network.PerceptualNet``
    and cannot be confirmed from this file.
    """
    net = network.PerceptualNet()
    print("-- Perceptual network is created! --")
    return net


# ----------------------------------------
#             PATH processing
# ----------------------------------------
def text_readlines(filename):
    """Read a text file and return its lines without line terminators.

    Args:
        filename: path of the text file to read.

    Returns:
        A list with one string per line of the file, each with its trailing
        newline character removed. An empty list is returned when the file
        cannot be opened.
    """
    try:
        handle = open(filename, "r")
    except IOError:
        # Best-effort contract: an unreadable file yields an empty list.
        return []
    # The context manager closes the file even if reading fails part-way.
    with handle:
        # Only strip a newline that is really there: the last line of a file
        # often has none, and blindly dropping the final character would
        # truncate it.
        return [
            line[:-1] if line.endswith("\n") else line for line in handle
        ]


def savetxt(name, loss_log):
    """Convert ``loss_log`` to a NumPy array and write it to ``name``.

    Args:
        name: output target handed to ``numpy.savetxt``.
        loss_log: values to save; converted with ``numpy.array`` first.
    """
    np.savetxt(name, np.array(loss_log))


def get_files(path, mask=False):
    """Recursively list the files under ``path`` as joined paths.

    Args:
        path: root directory to walk.
        mask: unused; kept so existing callers keep working.

    Returns:
        A list of ``os.path.join(root, name)`` for every file found by
        ``os.walk(path)``, excluding macOS ``.DS_Store`` metadata files.
    """
    ret = []
    for root, _dirs, files in os.walk(path):
        for filename in files:
            # Skip Finder metadata. The previous check was inverted and
            # skipped everything *except* these files.
            if filename == ".DS_Store":
                continue
            ret.append(os.path.join(root, filename))
    return ret


def get_names(path):
    """Recursively collect the bare file names found under ``path``.

    Args:
        path: root directory to walk.

    Returns:
        A list of file names (without their directory part) in the order
        ``os.walk`` yields them.
    """
    names = []
    for _root, _dirs, files in os.walk(path):
        # Each walk step already provides the names as a list.
        names.extend(files)
    return names


def text_save(content, filename, mode="a"):
    """Write every item of ``content`` to ``filename``, one per line.

    Args:
        content: iterable of items; each is converted with ``str``.
        filename: path of the text file to write.
        mode: mode string passed to ``open``; defaults to ``"a"`` (append).
    """
    # The context manager closes the handle even if a write fails.
    with open(filename, mode) as handle:
        for item in content:
            handle.write(str(item) + "\n")


def check_path(path):
    """Create ``path`` (including missing parents) unless it already exists.

    Args:
        path: directory path to ensure.
    """
    if os.path.exists(path):
        # Something is already there; leave it alone.
        return
    os.makedirs(path)


# ----------------------------------------
#    Validation and Sample at training
# ----------------------------------------
def save_sample_png(
    sample_folder, sample_name, img_list, name_list, pixel_max_cnt=255
):
    """Convert each tensor in ``img_list`` to an 8-bit image and save it.

    Every image is written to ``os.path.join(sample_folder, sample_name)``,
    so when ``img_list`` holds several tensors only the last one remains on
    disk.

    Args:
        sample_folder: directory of the output file.
        sample_name: file name of the output file.
        img_list: sequence of tensors; each is permuted with
            ``(0, 2, 3, 1)`` and only index 0 of the first axis is kept.
            Presumably (batch, channel, height, width) RGB in [0, 1];
            confirm against the caller.
        name_list: unused.
        pixel_max_cnt: upper bound used when clipping pixel values.
    """
    for tensor in img_list:
        # Scale back to pixel range; the original note attributes the factor
        # of 255 to a sigmoid-activated last layer.
        scaled = tensor * 255
        # Work on a copy so the caller's tensor is left untouched.
        first_item = scaled.clone().data.permute(0, 2, 3, 1)[0, :, :, :]
        pixels = first_item.to("cpu").numpy()
        pixels = np.clip(pixels, 0, pixel_max_cnt).astype(np.uint8)
        # OpenCV writes BGR channel order.
        bgr = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(sample_folder, sample_name), bgr)


def psnr(pred, target, pixel_max_cnt=255):
    """Compute the peak signal-to-noise ratio between two tensors.

    Args:
        pred: predicted tensor.
        target: reference tensor; must be subtractable from ``pred``.
        pixel_max_cnt: peak value used as the numerator of the ratio.

    Returns:
        ``20 * log10(pixel_max_cnt / rmse)`` where ``rmse`` is the root of
        the mean squared difference over all elements, or ``inf`` when the
        two tensors are identical (``rmse == 0``).
    """
    diff = target - pred
    rmse_avg = torch.mean(torch.mul(diff, diff)).item() ** 0.5
    if rmse_avg == 0:
        # Identical inputs: avoid a ZeroDivisionError and report the
        # conventional infinite PSNR instead.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt / rmse_avg)


def grey_psnr(pred, target, pixel_max_cnt=255):
    """Compute PSNR after summing both tensors over their first axis.

    Args:
        pred: predicted tensor.
        target: reference tensor with the same shape as ``pred``.
        pixel_max_cnt: per-channel peak value; it is multiplied by 3 for the
            ratio. NOTE(review): the factor 3 suggests axis 0 is expected to
            hold three colour channels; confirm against the caller.

    Returns:
        ``20 * log10(pixel_max_cnt * 3 / rmse)`` where ``rmse`` is the root
        of the mean squared difference of the summed tensors, or ``inf``
        when they are identical (``rmse == 0``).
    """
    pred_sum = torch.sum(pred, dim=0)
    target_sum = torch.sum(target, dim=0)
    diff = target_sum - pred_sum
    rmse_avg = torch.mean(torch.mul(diff, diff)).item() ** 0.5
    if rmse_avg == 0:
        # Identical inputs: avoid a ZeroDivisionError and report the
        # conventional infinite PSNR instead.
        return float("inf")
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)


def ssim(pred, target):
    """Compute the structural similarity of the first item of two batches.

    Args:
        pred: predicted tensor; permuted with ``(0, 2, 3, 1)`` and only
            index 0 of the first axis is compared.
        target: reference tensor, handled the same way as ``pred``.

    Returns:
        The value returned by ``skimage.measure.compare_ssim`` for the two
        converted arrays, with ``multichannel=True``.
    """
    # ``Tensor.numpy()`` only works for tensors in CPU memory, so move the
    # copies to the CPU explicitly instead of to the configured compute
    # device (which fails as soon as that device is a GPU).
    pred_np = pred.clone().data.permute(0, 2, 3, 1).to("cpu").numpy()[0]
    target_np = target.clone().data.permute(0, 2, 3, 1).to("cpu").numpy()[0]
    # NOTE(review): ``compare_ssim`` was removed from ``skimage.measure`` in
    # newer scikit-image releases in favour of
    # ``skimage.metrics.structural_similarity``; confirm the pinned version.
    return skimage.measure.compare_ssim(target_np, pred_np, multichannel=True)