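"""Reference-based manga colorization inference script.

Loads a ColorEncoder / ColorUNet checkpoint, encodes the colours of a single
reference image, and applies them to every image found (recursively) in the
input folder, writing results into a mirrored folder structure under the
output folder.

Example invocation (script name and paths are illustrative, not taken from
the original repository):

    python inference.py -i ./manga_pages -r ./reference.png \
        -ckpt ./checkpoint.pt -o ./results
"""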
import os
import argparse
import numpy as np
from skimage import color, io
import torch
import torch.nn.functional as F
from PIL import Image
from models import ColorEncoder, ColorUNet
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

def mkdirs(path):
    if not os.path.exists(path):
        os.makedirs(path)

def Lab2RGB_out(img_lab):
    """Convert a normalized Lab tensor (L centered around 0) back to a uint8 RGB image."""
    img_lab = img_lab.detach().cpu()
    img_l = img_lab[:, :1, :, :]
    img_ab = img_lab[:, 1:, :, :]
    img_l = img_l + 50  # undo the L-channel centering applied in Normalize
    pred_lab = torch.cat((img_l, img_ab), 1)[0, ...].numpy()
    out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
    return out

def RGB2Lab(inputs):
    return color.rgb2lab(inputs)


def Normalize(inputs):
    l = inputs[:, :, 0:1]
    ab = inputs[:, :, 1:3]
    l = l - 50
    lab = np.concatenate((l, ab), 2)
    return lab.astype('float32')


def numpy2tensor(inputs):
    out = torch.from_numpy(inputs.transpose(2, 0, 1))
    return out

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colorize manga images.")
    parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
    parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
    parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
    parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
    parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
    args = parser.parse_args()

    # Prefer the GPU when available; the checkpoint itself is loaded to CPU first and the
    # modules are then moved to the chosen device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    ckpt = torch.load(args.model_checkpoint, map_location=lambda storage, loc: storage)

    colorEncoder = ColorEncoder().to(device)
    colorEncoder.load_state_dict(ckpt["colorEncoder"])
    colorEncoder.eval()

    colorUNet = ColorUNet().to(device)
    colorUNet.load_state_dict(ckpt["colorUNet"])
    colorUNet.eval()

    # Load and preprocess the reference image that drives the colorization.
    reference_img = Image.open(args.reference_image).convert("RGB")
    reference_img = np.array(reference_img).astype(np.float32) / 255.0  # make sure the reference is in the [0, 1] range
    reference_img_lab = RGB2Lab(reference_img)
    reference_img_lab = Normalize(reference_img_lab)
    reference_img_lab = numpy2tensor(reference_img_lab)
    reference_img_lab = reference_img_lab.to(device).unsqueeze(0)
    # The color encoder consumes a fixed 256x256 Lab input (values divided by 110),
    # so the reference can be resized once, up front.
    reference_resize = F.interpolate(reference_img_lab / 110., size=(256, 256), mode='bilinear', align_corners=False)

    # Walk the input folder recursively and colorize every supported image.
    for root, dirs, files in os.walk(args.input_folder):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                input_image_path = os.path.join(root, file)
                img = Image.open(input_image_path).convert("RGB")
                img = np.array(img).astype(np.float32) / 255.0  # make sure the input image is in the [0, 1] range
                img_lab = RGB2Lab(img)
                img_lab = Normalize(img_lab)
                img_lab = numpy2tensor(img_lab)
                img_lab = img_lab.to(device).unsqueeze(0)

                with torch.no_grad():
                    # The networks operate at 256x256; only the (centered) L channel of the target page is used.
                    img_L_resize = F.interpolate(img_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', align_corners=False)
                    # Encode the colours of the reference image and predict ab channels for the target.
                    color_vector = colorEncoder(reference_resize)
                    fake_ab = colorUNet((img_L_resize, color_vector))
                    # Upsample the predicted ab channels back to the original resolution.
                    fake_ab = F.interpolate(fake_ab, size=(img.shape[0], img.shape[1]), mode='bilinear', align_corners=False)

                # Recombine the original L channel with the predicted ab channels.
                # Lab2RGB_out already returns a uint8 image in [0, 255], so no further scaling is needed.
                fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
                fake_img = Lab2RGB_out(fake_img)

                # Mirror the input folder structure under the output folder, inside a 'color' subfolder.
                relative_path = os.path.relpath(input_image_path, args.input_folder)
                output_subfolder = os.path.join(args.output_folder, os.path.dirname(relative_path), 'color')
                mkdirs(output_subfolder)
                output_image_path = os.path.join(output_subfolder, f'{os.path.splitext(os.path.basename(input_image_path))[0]}_colorized.png')
                io.imsave(output_image_path, fake_img)

    print(f'Colored images have been saved to: {args.output_folder}')