File size: 3,190 Bytes
0caed3c
 
 
f6df16f
0caed3c
 
f6df16f
0caed3c
 
063c371
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
f6df16f
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
df23063
0caed3c
 
063c371
 
 
 
 
 
 
 
 
 
 
0caed3c
f199ded
df23063
f6df16f
 
 
 
 
 
 
 
df23063
 
 
 
 
 
 
 
 
f6df16f
 
 
 
 
063c371
f6df16f
 
 
 
063c371
f6df16f
 
 
063c371
f6df16f
063c371
f6df16f
 
063c371
f6df16f
 
063c371
f6df16f
 
 
 
063c371
f6df16f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import numpy as np
from skimage import color, io

import torch
import torch.nn.functional as F

from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor

# Expose only GPU 0 to this process. This is done at import time and
# overrides any CUDA_VISIBLE_DEVICES value already present in the environment.
# NOTE: the variable only takes effect if set before CUDA is initialized in
# this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def mkdirs(path):
    """Create directory *path* (including missing parents) if it does not exist.

    Args:
        path: Directory path to create. An empty string (as returned by
            ``os.path.dirname`` for a bare filename) refers to the current
            working directory, which always exists, so nothing is done.

    If *path* already exists, nothing is done.
    """
    if not path or os.path.exists(path):
        return
    # exist_ok=True tolerates another process creating the directory between
    # the exists() check above and this call.
    os.makedirs(path, exist_ok=True)

def Lab2RGB_out(img_lab):
    """Convert the first image of a shifted-Lab batch tensor to an 8-bit RGB array.

    Args:
        img_lab: 4-D tensor indexed as (batch, channel, H, W). Channel 0 is L
            shifted by -50 (the inverse of ``Normalize``) and channels 1: are
            a/b. Presumably (N, 3, H, W) -- TODO confirm with callers.

    Returns:
        ``uint8`` numpy array laid out as (H, W, C) with RGB values for batch
        element 0. Values are clipped to [0, 1], scaled by 255, and truncated
        (not rounded) by the ``astype`` cast.
    """
    host = img_lab.detach().cpu()
    lightness = host[:, :1, :, :] + 50  # undo the -50 shift applied by Normalize
    chroma = host[:, 1:, :, :]
    first_chw = torch.cat((lightness, chroma), 1)[0, ...].numpy()
    rgb = color.lab2rgb(first_chw.transpose(1, 2, 0))
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")

def RGB2Lab(inputs):
    """Convert an RGB image to CIE Lab using ``skimage.color.rgb2lab``.

    Args:
        inputs: Anything ``rgb2lab`` accepts as an RGB image (in this file, a
            PIL RGB image via ``preprocessing``).

    Returns:
        The Lab array produced by ``rgb2lab``.
    """
    lab = color.rgb2lab(inputs)
    return lab

def Normalize(inputs):
    """Shift the L channel of an (H, W, C) Lab array by -50 and cast to float32.

    Args:
        inputs: 3-D array whose last axis holds L at index 0 and a/b at
            indices 1 and 2. Any channels beyond index 2 are dropped.

    Returns:
        float32 array of shape (H, W, 3) with L - 50, a, b.
    """
    shifted_l = inputs[:, :, 0:1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate((shifted_l, chroma), axis=2).astype('float32')

def numpy2tensor(inputs):
    """Wrap an (H, W, C) numpy array as a (C, H, W) torch tensor.

    The result is created with ``torch.from_numpy`` and therefore shares
    memory with *inputs*.
    """
    channels_first = inputs.transpose((2, 0, 1))
    return torch.from_numpy(channels_first)

def tensor2numpy(inputs):
    """Return batch element 0 of a (N, C, H, W) tensor as an (H, W, C) numpy array.

    The tensor is detached and moved to the CPU before conversion.
    """
    first_image = inputs[0, ...]
    host_array = first_image.detach().cpu().numpy()
    return host_array.transpose((1, 2, 0))

def preprocessing(inputs):
    """Build batched RGB and shifted-Lab tensors from an RGB image.

    Args:
        inputs: An RGB image accepted by both ``np.array`` and ``RGB2Lab``
            (``__main__`` passes a PIL image converted to "RGB").

    Returns:
        Tuple ``(rgb, lab)`` of tensors shaped (1, C, H, W):
        ``rgb`` holds the raw pixel values as float32 (not rescaled), and
        ``lab`` holds the Lab image with L shifted by -50 (see ``Normalize``).
    """
    lab = Normalize(RGB2Lab(inputs))
    rgb = np.array(inputs, 'float32')
    rgb_batch = numpy2tensor(rgb).unsqueeze(0)
    lab_batch = numpy2tensor(lab).unsqueeze(0)
    return rgb_batch, lab_batch

if __name__ == "__main__":
    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Specify the paths here
    img_path = 'path/to/your/input/image.jpg'
    ckpt_path = 'path/to/your/model_checkpoint.pt'
    reference_image_path = 'path/to/your/reference/image.jpg'

    # Square side length fed to the networks; inputs are resized to this.
    imgsize = 256

    # Deserialize every tensor onto the CPU; the models are moved to `device`
    # below and load_state_dict copies the weights onto it.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    img_name = os.path.splitext(os.path.basename(img_path))[0]
    img1 = Image.open(img_path).convert("RGB")
    width, height = img1.size  # PIL reports (width, height)
    img1, img1_lab = preprocessing(img1)
    img2, img2_lab = preprocessing(Image.open(reference_image_path).convert("RGB"))

    img1 = img1.to(device)
    img1_lab = img1_lab.to(device)
    img2 = img2.to(device)
    img2_lab = img2_lab.to(device)

    with torch.no_grad():
        # Reference RGB scaled from [0, 255] to [0, 1] for the color encoder.
        img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
        # Target L channel (already shifted by -50 in Normalize) divided by 50.
        img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)

        color_vector = colorEncoder(img2_resize)

        fake_ab = colorUNet((img1_L_resize, color_vector))
        # Scale predicted ab by 110 (presumably the network emits roughly
        # [-1, 1] -- TODO confirm) and upsample back to the original size.
        fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)

        # Recombine the full-resolution original L with the predicted ab.
        fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
        fake_img = Lab2RGB_out(fake_img)

    # Save next to the input image; a bare filename has an empty dirname,
    # which means the current directory.
    out_folder = os.path.dirname(img_path) or os.curdir
    mkdirs(out_folder)
    out_img_path = os.path.join(out_folder, f'{img_name}_color.png')
    io.imsave(out_img_path, fake_img)

    print(f'Colored image has been saved to {out_img_path}.')