Keiser41 committed on
Commit a424b9d · 1 Parent(s): 0809b65

Update pintar.py

Files changed (1)
  1. pintar.py +40 -50
pintar.py CHANGED
@@ -6,7 +6,6 @@ import torch
 import torch.nn.functional as F
 from PIL import Image
 from models import ColorEncoder, ColorUNet
-from extractor.manga_panel_extractor import PanelExtractor
 
 os.environ["CUDA_VISIBLE_DEVICES"] = '0'
 
@@ -54,7 +53,6 @@ if __name__ == "__main__":
     parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
     parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
     parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
-    parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
     args = parser.parse_args()
 
     device = "cuda"
@@ -69,51 +67,43 @@ if __name__ == "__main__":
     colorUNet.load_state_dict(ckpt["colorUNet"])
     colorUNet.eval()
 
-    if args.no_extractor:
-        # Colorize a single image without panel extraction
-        img_path = args.input_folder
-        ref_img_path = args.reference_image
-
-        img1 = Image.open(ref_img_path).convert("RGB")
-        width, height = img1.size
-        img2 = Image.open(img_path).convert("RGB")
-
-        img1, img1_lab = preprocessing(img1)
-        img2, img2_lab = preprocessing(img2)
-
-        img1 = img1.to(device)
-        img1_lab = img1_lab.to(device)
-        img2 = img2.to(device)
-        img2_lab = img2_lab.to(device)
-
-        with torch.no_grad():
-            img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
-            img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
-
-            color_vector = colorEncoder(img2_resize)
-            fake_ab = colorUNet((img1_L_resize, color_vector))
-            fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
-
-        fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
-        fake_img = Lab2RGB_out(fake_img)
-
-        out_folder = os.path.join(args.output_folder, 'color')
-        mkdirs(out_folder)
-        out_img_path = os.path.join(out_folder, 'colorized_image.png')
-        io.imsave(out_img_path, fake_img)
-
-        if panel_num == 1:
-            out_folder = os.path.dirname(img_path)
-            out_name = os.path.basename(img_path)
-            out_name = os.path.splitext(out_name)[0]
-            out_img_path = os.path.join(out_folder,'color',f'{out_name}_color.png')
-
-            Image.fromarray(fake_imgs[0]).show()
-            folder_path = os.path.join(out_folder, 'color')
-            if not os.path.exists(folder_path):
-                os.makedirs(folder_path)
-            io.imsave(out_img_path, fake_imgs[0])
-        else:
-            panel_extractor.concatPanels(img_path, fake_imgs, masks, panel_masks)
-
-        print(f'Colored images have been saved to: {os.path.join(test_dir_path, "color")}')
+    input_folder = args.input_folder
+    output_folder = args.output_folder
+    reference_image_path = args.reference_image
+
+    for root, dirs, files in os.walk(input_folder):
+        for file in files:
+            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
+                input_image_path = os.path.join(root, file)
+
+                img1 = Image.open(reference_image_path).convert("RGB")
+                width, height = img1.size
+                img2 = Image.open(input_image_path).convert("RGB")
+
+                img1, img1_lab = preprocessing(img1)
+                img2, img2_lab = preprocessing(img2)
+
+                img1 = img1.to(device)
+                img1_lab = img1_lab.to(device)
+                img2 = img2.to(device)
+                img2_lab = img2_lab.to(device)
+
+                with torch.no_grad():
+                    img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
+                    img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
+
+                    color_vector = colorEncoder(img2_resize)
+                    fake_ab = colorUNet((img1_L_resize, color_vector))
+                    fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
+
+                fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
+                fake_img = Lab2RGB_out(fake_img)
+
+                relative_path = os.path.relpath(input_image_path, input_folder)
+                output_subfolder = os.path.join(output_folder, os.path.dirname(relative_path), 'color')
+                mkdirs(output_subfolder)
+                output_image_path = os.path.join(output_subfolder, f'{os.path.splitext(os.path.basename(input_image_path))[0]}_colorized.png')
+                io.imsave(output_image_path, fake_img)
+
+    print(f'Colored images have been saved to: {output_folder}')
+
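The updated loop drops the panel-extractor path in favor of a recursive walk: every image found under the input folder is colorized against the single reference image, and its result is written into a mirrored subtree of the output folder. Below is a minimal, hypothetical sketch of just that path-mapping step, not part of pintar.py; the helper name colorized_output_path is made up for illustration, and os.makedirs stands in for the script's own mkdirs utility.

# Minimal sketch (hypothetical helper, not in pintar.py): map an input image path
# to the "<output>/<relative subdir>/color/<name>_colorized.png" path used above.
import os

def colorized_output_path(input_image_path, input_folder, output_folder):
    # e.g. "<input>/ch01/page_003.png" -> relative_path = "ch01/page_003.png"
    relative_path = os.path.relpath(input_image_path, input_folder)
    output_subfolder = os.path.join(output_folder, os.path.dirname(relative_path), 'color')
    os.makedirs(output_subfolder, exist_ok=True)  # stands in for the script's mkdirs()
    stem = os.path.splitext(os.path.basename(input_image_path))[0]
    return os.path.join(output_subfolder, f'{stem}_colorized.png')

# Example: "<input>/ch01/page_003.png" -> "<output>/ch01/color/page_003_colorized.png"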