# NOTE(review): removed web-scrape residue that preceded the source
# (file-size banner, git commit hashes, and a 1-98 line-number gutter);
# none of it was part of the Python file and it broke the syntax.
import os
import argparse
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet

# Restrict this process to the first GPU; must be set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def mkdirs(path):
    """Create directory *path* (and parents) if it does not already exist.

    Uses ``exist_ok=True`` instead of a separate existence check, which
    avoids the check-then-create race when two processes target the same
    output folder.
    """
    os.makedirs(path, exist_ok=True)

def Lab2RGB_out(img_lab):
    """Convert a normalized Lab tensor of shape (1, 3, H, W) to a uint8 RGB image.

    The L channel is assumed to have been shifted by -50 (see Normalize);
    it is restored to its native range before conversion.

    Returns an (H, W, 3) uint8 array with values in [0, 255].
    """
    lab = img_lab.detach().cpu()
    # Undo the L-channel shift applied during normalization; ab stays as-is.
    restored = torch.cat((lab[:, :1, :, :] + 50, lab[:, 1:, :, :]), dim=1)
    lab_hwc = restored[0, ...].numpy().transpose(1, 2, 0)
    rgb = color.lab2rgb(lab_hwc)
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")

def RGB2Lab(inputs):
    """Convert an RGB image array to CIE Lab via scikit-image."""
    converted = color.rgb2lab(inputs)
    return converted

def Normalize(inputs):
    """Center an (H, W, 3) Lab image for the network.

    Shifts the L channel by -50 (from [0, 100] toward [-50, 50]) while
    leaving the ab channels untouched, and casts the result to float32.
    """
    luminance = inputs[:, :, 0:1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate((luminance, chroma), axis=2).astype('float32')

def numpy2tensor(inputs):
    """Convert an (H, W, C) numpy image into a (C, H, W) torch tensor."""
    chw = inputs.transpose(2, 0, 1)
    return torch.from_numpy(chw)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
    parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
    parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
    parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable instead of crashing at .to(device).
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load weights onto CPU first; .to(device) moves the modules afterwards.
    ckpt = torch.load(args.model_checkpoint, map_location=lambda storage, loc: storage)

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    # Prepare the color reference once: RGB in [0, 1] -> Lab -> normalized (1, 3, H, W) tensor.
    reference_img = Image.open(args.reference_image).convert("RGB")
    reference_img = np.array(reference_img).astype(np.float32) / 255.0  # keep in [0, 1] for rgb2lab
    reference_img_lab = numpy2tensor(Normalize(RGB2Lab(reference_img)))
    reference_img_lab = reference_img_lab.to(device).unsqueeze(0)

    for root, _dirs, files in os.walk(args.input_folder):
        for file in files:
            if not file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                continue
            input_image_path = os.path.join(root, file)

            img = Image.open(input_image_path).convert("RGB")
            img = np.array(img).astype(np.float32) / 255.0  # keep in [0, 1] for rgb2lab
            img_lab = numpy2tensor(Normalize(RGB2Lab(img)))
            img_lab = img_lab.to(device).unsqueeze(0)

            with torch.no_grad():
                # BUG FIX: the color vector must come from the REFERENCE image.
                # The original fed the input image to colorEncoder and never used
                # reference_img_lab at all, defeating reference-based colorization.
                # NOTE: recompute_scale_factor was dropped — passing it together
                # with an explicit `size` raises ValueError in current PyTorch.
                ref_resize = F.interpolate(reference_img_lab / 110., size=(256, 256),
                                           mode='bilinear', align_corners=False)
                # BUG FIX: slice L from img_lab directly; the original sliced it from
                # an already /110-scaled tensor and then divided by 50 again.
                # NOTE(review): /110 (ab range) and /50 (L range) scaling assumed to
                # match the checkpoint's training normalization — confirm upstream.
                img_L_resize = F.interpolate(img_lab[:, :1, :, :] / 50., size=(256, 256),
                                             mode='bilinear', align_corners=False)

                color_vector = colorEncoder(ref_resize)
                fake_ab = colorUNet((img_L_resize, color_vector))
                # Upsample predicted ab back to the input's full resolution.
                fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]),
                                        mode='bilinear', align_corners=False)

                # Combine the full-resolution L channel with the predicted ab.
                fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
                fake_img = Lab2RGB_out(fake_img)
                # BUG FIX: removed `(fake_img * 255).astype(np.uint8)` — Lab2RGB_out
                # already returns uint8 in [0, 255]; rescaling wrapped the values.

            # Mirror the input folder structure under <output>/<subdirs>/color/.
            relative_path = os.path.relpath(input_image_path, args.input_folder)
            output_subfolder = os.path.join(args.output_folder, os.path.dirname(relative_path), 'color')
            mkdirs(output_subfolder)
            base_name = os.path.splitext(os.path.basename(input_image_path))[0]
            output_image_path = os.path.join(output_subfolder, f'{base_name}_colorized.png')
            io.imsave(output_image_path, fake_img)

    print(f'Colored images have been saved to: {args.output_folder}')