Keiser41 commited on
Commit
1e117bb
·
1 Parent(s): a424b9d

Update pintar.py

Browse files
Files changed (1) hide show
  1. pintar.py +17 -25
pintar.py CHANGED
@@ -67,43 +67,35 @@ if __name__ == "__main__":
67
  colorUNet.load_state_dict(ckpt["colorUNet"])
68
  colorUNet.eval()
69
 
70
- input_folder = args.input_folder
71
- output_folder = args.output_folder
72
- reference_image_path = args.reference_image
 
73
 
74
- for root, dirs, files in os.walk(input_folder):
75
  for file in files:
76
  if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
77
  input_image_path = os.path.join(root, file)
78
 
79
- img1 = Image.open(reference_image_path).convert("RGB")
80
- width, height = img1.size
81
- img2 = Image.open(input_image_path).convert("RGB")
82
-
83
- img1, img1_lab = preprocessing(img1)
84
- img2, img2_lab = preprocessing(img2)
85
-
86
- img1 = img1.to(device)
87
- img1_lab = img1_lab.to(device)
88
- img2 = img2.to(device)
89
- img2_lab = img2_lab.to(device)
90
 
91
  with torch.no_grad():
92
- img2_resize = F.interpolate(img2 / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
93
- img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
94
 
95
- color_vector = colorEncoder(img2_resize)
96
- fake_ab = colorUNet((img1_L_resize, color_vector))
97
- fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False)
98
 
99
- fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
100
  fake_img = Lab2RGB_out(fake_img)
101
 
102
- relative_path = os.path.relpath(input_image_path, input_folder)
103
- output_subfolder = os.path.join(output_folder, os.path.dirname(relative_path), 'color')
104
  mkdirs(output_subfolder)
105
  output_image_path = os.path.join(output_subfolder, f'{os.path.splitext(os.path.basename(input_image_path))[0]}_colorized.png')
106
  io.imsave(output_image_path, fake_img)
107
 
108
- print(f'Colored images have been saved to: {output_folder}')
109
-
 
67
  colorUNet.load_state_dict(ckpt["colorUNet"])
68
  colorUNet.eval()
69
 
70
+ reference_img = Image.open(args.reference_image).convert("RGB")
71
+ reference_img, reference_img_lab = preprocessing(reference_img)
72
+ reference_img = reference_img.to(device)
73
+ reference_img_lab = reference_img_lab.to(device)
74
 
75
+ for root, dirs, files in os.walk(args.input_folder):
76
  for file in files:
77
  if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
78
  input_image_path = os.path.join(root, file)
79
 
80
+ img, img_lab = preprocessing(Image.open(input_image_path).convert("RGB"))
81
+ img = img.to(device)
82
+ img_lab = img_lab.to(device)
 
 
 
 
 
 
 
 
83
 
84
  with torch.no_grad():
85
+ img_resize = F.interpolate(img / 255., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
86
+ img_L_resize = F.interpolate(img_lab[:, :1, :, :] / 50., size=(256, 256), mode='bilinear', recompute_scale_factor=False, align_corners=False)
87
 
88
+ color_vector = colorEncoder(img_resize)
89
+ fake_ab = colorUNet((img_L_resize, color_vector))
90
+ fake_ab = F.interpolate(fake_ab * 110, size=(img.size(2), img.size(3)), mode='bilinear', recompute_scale_factor=False, align_corners=False)
91
 
92
+ fake_img = torch.cat((img_lab[:, :1, :, :], fake_ab), 1)
93
  fake_img = Lab2RGB_out(fake_img)
94
 
95
+ relative_path = os.path.relpath(input_image_path, args.input_folder)
96
+ output_subfolder = os.path.join(args.output_folder, os.path.dirname(relative_path), 'color')
97
  mkdirs(output_subfolder)
98
  output_image_path = os.path.join(output_subfolder, f'{os.path.splitext(os.path.basename(input_image_path))[0]}_colorized.png')
99
  io.imsave(output_image_path, fake_img)
100
 
101
+ print(f'Colored images have been saved to: {args.output_folder}')