import numpy as np
import shutil
import os
import torch
import glob
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.vit.embed import GeneralEmbedModel
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.utils import (
    RGB2Lab,
    ToTensor,
    CenterPad,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
    UnpaddingSquare,
)
from src.models.CNN.FrameColor import frame_colorization

weight_path = "./ckp/12/"
out_path = "./output_video/"
root_path = "./EvalDataset"
device = "cuda"

# Start from a clean output directory; ignore_errors avoids a crash when the
# directory does not exist yet.
shutil.rmtree(out_path, ignore_errors=True)
os.makedirs(out_path)

videos_list = os.listdir(os.path.join(root_path, "clips"))


def load_params(ckpt_file):
    # torch.load already returns an (ordered) state dict; map it onto the
    # target device so checkpoints saved on another device still load.
    return torch.load(ckpt_file, map_location=device)


embed_net = GeneralEmbedModel(pretrained_model="swin-small", device=device).to(device).eval()
nonlocal_net = GeneralWarpNet(feature_channel=128).to(device).eval()
colornet = GeneralColorVidNet(7).to(device).eval()

# Pick the newest checkpoint of each sub-network; glob order is unspecified,
# so sort before taking the last entry.
embed_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "embed_net*.pth")))[-1]),
    strict=False,
)
nonlocal_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "nonlocal_net*.pth")))[-1])
)
colornet.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "colornet*.pth")))[-1])
)


def custom_transform(trans_list, img):
    # Run the transform pipeline, capturing the paddings produced by
    # SquaredPadding so they can be undone later if needed. Initialize
    # padding so the function is safe even without a SquaredPadding step.
    padding = None
    for trans in trans_list:
        if isinstance(trans, SquaredPadding):
            img, padding = trans(img, return_paddings=True)
        else:
            img = trans(img)
    return img.to(device), padding


transformer = [
    SquaredPadding(target_size=224),
    RGB2Lab(),
    ToTensor(),
    Normalize(),
]
high_resolution = True
center_padder = CenterPad((224, 224))

with torch.no_grad():
    for video_name in tqdm(videos_list):
        frames_list = sorted(os.listdir(os.path.join(root_path, "clips", video_name)))
        ref_dir = os.path.join(root_path, "ref", video_name)
        ref_path = os.path.join(ref_dir, os.listdir(ref_dir)[0])

        # The "previous prediction" starts as zeros for the first frame.
        I_last_lab_predict = torch.zeros((1, 3, 224, 224)).to(device)

        video_out_path = os.path.join(out_path, video_name)
        os.mkdir(video_out_path)

        # Encode the color reference once per video.
        ref_frame_pil_rgb = Image.open(ref_path).convert("RGB")
        I_reference_lab, I_reference_padding = custom_transform(transformer, center_padder(ref_frame_pil_rgb))
        I_reference_lab = torch.unsqueeze(I_reference_lab, 0)
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)

        for frame_name in frames_list:
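            # Colorize each grayscale frame conditioned on the reference
            # features and on the previous frame's prediction, which keeps
            # the colors temporally consistent across the clip.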
            current_frame_pil_rgb = Image.open(os.path.join(root_path, "clips", video_name, frame_name)).convert("RGB")
            im_w, im_h = current_frame_pil_rgb.size

            I_current_lab, I_current_padding = custom_transform(transformer, current_frame_pil_rgb)
            I_current_lab = torch.unsqueeze(I_current_lab, 0)
            I_current_l = I_current_lab[:, 0:1, :, :]

            I_current_ab_predict, _ = frame_colorization(
                IA_l=I_current_l,
                IB_lab=I_reference_lab,
                IA_last_lab=I_last_lab_predict,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=False,
                # temperature=1e-10,
            )

            if high_resolution:
                # Reuse the full-resolution luminance and only upsample the
                # predicted ab channels from 224x224 to the padded size.
                target_size = max(im_h, im_w)
                high_lab = transforms.Compose([
                    SquaredPadding(target_size=target_size),
                    RGB2Lab(),
                    ToTensor(),
                    Normalize(),
                ])
                high_lab_current = high_lab(current_frame_pil_rgb)
                high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
                high_l_current = high_lab_current[:, 0:1, :, :]
                # An explicit output size avoids the off-by-one rounding a
                # float scale_factor can introduce.
                upsampler = torch.nn.Upsample(size=(target_size, target_size), mode="bilinear", align_corners=False)
                high_ab_predict = upsampler(I_current_ab_predict)
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
            else:
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))

            # The saved frames keep the square padding; an UnpaddingSquare
            # instance could crop it back out using I_current_padding
            # (left disabled, as in the original run):
            # I_predict_rgb = unpadder(I_predict_rgb, I_current_padding)

            image_result = Image.fromarray(
                (I_predict_rgb[0] * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
            )
            # Feed this frame's 224x224 prediction to the next iteration.
            I_last_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
            image_result.save(os.path.join(video_out_path, frame_name))
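
# Expected on-disk layout, inferred from the paths used above (an assumption,
# not confirmed elsewhere in this file):
#   ./ckp/12/                      embed_net*.pth, nonlocal_net*.pth, colornet*.pth
#   ./EvalDataset/clips/<video>/   grayscale input frames (sorted by filename)
#   ./EvalDataset/ref/<video>/     a single color reference image
#   ./output_video/<video>/        colorized frames written by this script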