Keiser41 commited on
Commit
f199ded
·
1 Parent(s): df23063

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +41 -109
pintar.py CHANGED
@@ -20,7 +20,7 @@ def Lab2RGB_out(img_lab):
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
- out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
@@ -49,26 +49,17 @@ def preprocessing(inputs):
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
52
- device = "cuda"
53
-
54
- parser = argparse.ArgumentParser()
55
- parser.add_argument("--path", type=str, default=None, help="path of input image")
56
- parser.add_argument("--size", type=int, default=None)
57
- parser.add_argument("--ckpt", type=str, default=None, help="path of model weight")
58
- parser.add_argument("-ne", "--no_extractor", action='store_true', help="Do not segment the manga panels.")
59
-
60
  args = parser.parse_args()
61
 
62
- if args.path:
63
- test_dir_path = args.path
64
- if args.size:
65
- imgsize = args.size
66
- if args.ckpt:
67
- ckpt_path = args.ckpt
68
- if args.no_extractor:
69
- no_extractor = args.no_extractor
70
 
71
- ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
72
 
73
  colorEncoder = ColorEncoder().to(device)
74
  colorEncoder.load_state_dict(ckpt["colorEncoder"])
@@ -78,97 +69,38 @@ if __name__ == "__main__":
78
  colorUNet.load_state_dict(ckpt["colorUNet"])
79
  colorUNet.eval()
80
 
81
- imgs = []
82
- imgs_lab = []
83
-
84
- while 1:
85
- print(f'make sure both manga image and reference images are under this path {test_dir_path}')
86
- img_path = input("please input the name of image needed to be colorized (with file extension): ")
87
- img_path = os.path.join(test_dir_path, img_path)
88
- img_name = os.path.basename(img_path)
89
- img_name = os.path.splitext(img_name)[0]
90
-
91
- if no_extractor:
92
- ref_img_path = os.path.join(test_dir_path, input(f"Enter the reference image path: "))
93
-
94
- img1 = Image.open(img_path).convert("RGB")
95
- width, height = img1.size
96
- img2 = Image.open(ref_img_path).convert("RGB")
97
-
98
- img1, img1_lab = preprocessing(img1)
99
- img2, img2_lab = preprocessing(img2)
100
-
101
- img1 = img1.to(device)
102
- img1_lab = img1_lab.to(device)
103
- img2 = img2.to(device)
104
- img2_lab = img2_lab.to(device)
105
-
106
- with torch.no_grad():
107
- img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear',
108
- recompute_scale_factor=False, align_corners=False)
109
- img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(imgsize, imgsize), mode='bilinear',
110
- recompute_scale_factor=False, align_corners=False)
111
-
112
- color_vector = colorEncoder(img2_resize)
113
-
114
- fake_ab = colorUNet((img1_L_resize, color_vector))
115
- fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear',
116
- recompute_scale_factor=False, align_corners=False)
117
-
118
- fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
119
- fake_img = Lab2RGB_out(fake_img)
120
-
121
- out_folder = os.path.dirname(img_path)
122
- out_name = os.path.basename(img_path)
123
- out_name = os.path.splitext(out_name)[0]
124
- out_img_path = os.path.join(out_folder, 'color', f'{out_name}_color.png')
125
-
126
- # show image
127
- Image.fromarray(fake_img).show()
128
- # save image
129
- folder_path = os.path.join(out_folder, 'color')
130
- if not os.path.exists(folder_path):
131
- os.makedirs(folder_path)
132
- io.imsave(out_img_path, fake_img)
133
-
134
- continue
135
-
136
- panel_extractor = PanelExtractor(min_pct_panel=5, max_pct_panel=90)
137
- panels, masks, panel_masks = panel_extractor.extract(img_path)
138
- panel_num = len(panels)
139
-
140
- ref_img_paths = []
141
- print("Please enter the name of the reference image in order according to the number prompts on the picture")
142
- for i in range(panel_num):
143
- ref_img_path = os.path.join(test_dir_path, input(f"{i+1}/{panel_num} reference image:"))
144
- ref_img_paths.append(ref_img_path)
145
-
146
- fake_imgs = []
147
- for i in range(panel_num):
148
- img1 = Image.fromarray(panels[i]).convert("RGB")
149
- width, height = img1.size
150
- img2 = Image.open(ref_img_paths[i]).convert("RGB")
151
-
152
- img1, img1_lab = preprocessing(img1)
153
- img2, img2_lab = preprocessing(img2)
154
-
155
- img1 = img1.to(device)
156
- img1_lab = img1_lab.to(device)
157
- img2 = img2.to(device)
158
- img2_lab = img2_lab.to(device)
159
-
160
- with torch.no_grad():
161
- img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
162
- img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False)
163
-
164
- color_vector = colorEncoder(img2_resize)
165
-
166
- fake_ab = colorUNet((img1_L_resize, color_vector))
167
- fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
168
-
169
- fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1)
170
- fake_img = Lab2RGB_out(fake_img)
171
- fake_imgs.append(fake_img)
172
 
173
  if panel_num == 1:
174
  out_folder = os.path.dirname(img_path)
 
20
  img_ab = img_lab[:,1:,:,:]
21
  img_l = img_l + 50
22
  pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
23
+ out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
24
  return out
25
 
26
  def RGB2Lab(inputs):
 
49
  return img.unsqueeze(0), img_lab.unsqueeze(0)
50
 
51
  if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser(description="Colorize manga images.")
53
+ parser.add_argument("-i", "--input_folder", type=str, required=True, help="Path to the input folder containing manga images.")
54
+ parser.add_argument("-r", "--reference_image", type=str, required=True, help="Path to the reference image for colorization.")
55
+ parser.add_argument("-ckpt", "--model_checkpoint", type=str, required=True, help="Path to the model checkpoint file.")
56
+ parser.add_argument("-o", "--output_folder", type=str, required=True, help="Path to the output folder where colorized images will be saved.")
57
+ parser.add_argument("-ne", "--no_extractor", action="store_true", help="Do not segment the manga panels.")
 
 
58
  args = parser.parse_args()
59
 
60
+ device = "cuda"
 
 
 
 
 
 
 
61
 
62
+ ckpt = torch.load(args.model_checkpoint, map_location=lambda storage, loc: storage)
63
 
64
  colorEncoder = ColorEncoder().to(device)
65
  colorEncoder.load_state_dict(ckpt["colorEncoder"])
 
69
  colorUNet.load_state_dict(ckpt["colorUNet"])
70
  colorUNet.eval()
71
 
72
+ if args.no_extractor:
73
+ # Colorize a single image without panel extraction
74
+ img_path = args.input_folder
75
+ ref_img_path = args.reference_image
76
+
77
+ img1 = Image.open(img_path).convert("RGB")
78
+ width, height = img1.size
79
+ img2 = Image.open(ref_img_path).convert("RGB")
80
+
81
+ img1, img1_lab = preprocessing(img1)
82
+ img2, img2_lab = preprocessing(img2)
83
+
84
+ img1 = img1.to(device)
85
+ img1_lab = img1_lab.to(device)
86
+ img2 = img2.to(device)
87
+ img2_lab = img2_lab.to(device)
88
+
89
+ with torch.no_grad():
90
+ img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
91
+ img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
92
+
93
+ color_vector = colorEncoder(img2_resize)
94
+ fake_ab = colorUNet((img1_L_resize, color_vector))
95
+ fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
96
+
97
+ fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
98
+ fake_img = Lab2RGB_out(fake_img)
99
+
100
+ out_folder = os.path.join(args.output_folder, 'color')
101
+ mkdirs(out_folder)
102
+ out_img_path = os.path.join(out_folder, 'colorized_image.png')
103
+ io.imsave(out_img_path, fake_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  if panel_num == 1:
106
  out_folder = os.path.dirname(img_path)