File size: 4,600 Bytes
0caed3c
0004858
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f199ded
0caed3c
 
 
 
 
 
 
 
 
 
 
 
 
df23063
0caed3c
 
 
df23063
0caed3c
 
 
 
 
 
 
 
 
 
f199ded
 
 
 
 
 
df23063
0caed3c
f199ded
df23063
f199ded
df23063
 
 
 
 
 
 
 
 
f199ded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df23063
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import argparse
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet
from extractor.manga_panel_extractor import PanelExtractor

# Export CUDA_VISIBLE_DEVICES="0" for this process at import time. The
# assignment is unconditional, so it overwrites any value the caller already
# set in the environment.
# NOTE(review): this runs after `import torch`; presumably fine because CUDA
# is initialised lazily, but confirm if device selection ever misbehaves.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def mkdirs(path):
    """Create directory *path* (and any missing parents) if it does not exist.

    Calling this on a directory that already exists is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() first, and surfaces the "path is a file" case instead
    # of silently skipping it.
    os.makedirs(path, exist_ok=True)

def Lab2RGB_out(img_lab):
    """Convert the first item of a batched Lab tensor to a uint8 RGB array.

    The tensor is indexed as (batch, channel, height, width): channel 0 is
    treated as lightness and has 50 added back (undoing the ``- 50`` shift
    applied by ``Normalize``), the remaining channels are passed through
    unchanged.

    Args:
        img_lab: 4-D torch tensor holding Lab data, channels first.

    Returns:
        A (height, width, channel) ``uint8`` numpy array with values in 0-255.
    """
    lab = img_lab.detach().cpu()
    lightness = lab[:, :1, :, :] + 50
    chroma = lab[:, 1:, :, :]
    # Only the first batch element is converted.
    first_item = torch.cat((lightness, chroma), 1)[0, ...].numpy()
    rgb = color.lab2rgb(first_item.transpose(1, 2, 0))
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")

def RGB2Lab(inputs):
    """Convert an RGB image to Lab via ``skimage.color.rgb2lab``.

    Args:
        inputs: RGB image in any form accepted by ``color.rgb2lab``.

    Returns:
        The Lab representation returned by ``color.rgb2lab``.
    """
    lab_image = color.rgb2lab(inputs)
    return lab_image

def Normalize(inputs):
    """Shift the first channel of a channel-last Lab array by -50.

    Channels 1 and 2 are copied unchanged; any channel beyond index 2 is
    dropped. The input array is not modified.

    Args:
        inputs: Array indexed as (height, width, channel).

    Returns:
        A new ``float32`` array with channel 0 reduced by 50.
    """
    shifted_l = inputs[:, :, :1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate((shifted_l, chroma), axis=2).astype('float32')

def numpy2tensor(inputs):
    """Turn a channel-last array into a channel-first torch tensor.

    Args:
        inputs: Array indexed as (height, width, channel).

    Returns:
        A tensor indexed as (channel, height, width), created with
        ``torch.from_numpy`` on the transposed array.
    """
    channels_first = inputs.transpose(2, 0, 1)
    return torch.from_numpy(channels_first)

def tensor2numpy(inputs):
    """Return the first batch item of a tensor as a channel-last array.

    Args:
        inputs: Tensor indexed as (batch, channel, height, width).

    Returns:
        A numpy array indexed as (height, width, channel), detached from the
        autograd graph and moved to the CPU.
    """
    first_item = inputs[0, ...]
    as_array = first_item.detach().cpu().numpy()
    return as_array.transpose(1, 2, 0)

def preprocessing(inputs):
    """Build batched RGB and normalized-Lab tensors from one image.

    Args:
        inputs: RGB image accepted by both ``RGB2Lab`` and ``np.array``
            (the script passes a PIL image converted to "RGB").

    Returns:
        A tuple ``(rgb, lab)`` of channel-first tensors, each with a leading
        batch dimension of size 1. ``rgb`` holds the raw pixel values as
        float32; ``lab`` holds the Lab conversion after ``Normalize``.
    """
    lab_tensor = numpy2tensor(Normalize(RGB2Lab(inputs)))
    rgb_tensor = numpy2tensor(np.array(inputs, 'float32'))
    return rgb_tensor.unsqueeze(0), lab_tensor.unsqueeze(0)

def main():
    """Colorize one manga image with colors taken from a reference image.

    Parses the command line, loads the ``ColorEncoder`` and ``ColorUNet``
    weights from the checkpoint, runs the whole-image colorization path and
    writes the result to ``<output_folder>/color/colorized_image.png``.

    Only the whole-image path (``-ne/--no_extractor``) is implemented in this
    script. Invoking it without that flag exits with an argparse error rather
    than silently producing no output.
    """
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
    parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
    parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
    parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
    args = parser.parse_args()

    # This file contains no panel-extraction code path, so fail before the
    # (slow) model load instead of exiting without writing anything.
    if not args.no_extractor:
        parser.error(
            "panel extraction is not implemented in this script; "
            "re-run with -ne/--no_extractor to colorize the whole image"
        )

    # Fall back to the CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load checkpoint storages as saved (on CPU); the modules are moved to
    # `device` separately below.
    ckpt = torch.load(args.model_checkpoint, map_location=lambda storage, loc: storage)

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    # NOTE: despite its name, --input_folder is opened as a single image file.
    img_path = args.input_folder
    ref_img_path = args.reference_image

    img1 = Image.open(img_path).convert("RGB")
    width, height = img1.size  # kept to restore the original resolution
    img2 = Image.open(ref_img_path).convert("RGB")

    img1, img1_lab = preprocessing(img1)
    img2, img2_lab = preprocessing(img2)

    img1 = img1.to(device)
    img1_lab = img1_lab.to(device)
    img2 = img2.to(device)
    img2_lab = img2_lab.to(device)

    with torch.no_grad():
        # Both networks are fed 256x256 inputs: the reference RGB scaled by
        # 1/255 and the target lightness channel scaled by 1/50.
        img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
        img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)

        color_vector = colorEncoder(img2_resize)
        fake_ab = colorUNet((img1_L_resize, color_vector))
        # Scale the predicted ab channels by 110 and resize them back to the
        # input image's resolution.
        fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)

        # Recombine the full-resolution lightness with the predicted colors.
        fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
        fake_img = Lab2RGB_out(fake_img)

    out_folder = os.path.join(args.output_folder, 'color')
    mkdirs(out_folder)
    out_img_path = os.path.join(out_folder, 'colorized_image.png')
    io.imsave(out_img_path, fake_img)

    print(f'Colored images have been saved to: {out_folder}')


if __name__ == "__main__":
    main()