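# Exemplar-based video colorization, batch evaluation script.
# For every clip under <root_path>/clips/, the single reference image in
# <root_path>/ref/<clip>/ is embedded once; the frames are then colorized
# in order, with each prediction fed back in as the previous-frame prior.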
import glob
import os
import shutil
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.utils import (
    RGB2Lab,
    ToTensor,
    CenterPad,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
    UnpaddingSquare,
)
weight_path = "./ckp/12/"
out_path = "./output_video/"
root_path = "./EvalDataset"
device = "cuda"

# Recreate the output directory from scratch on every run (a plain rmtree
# would raise if the directory does not exist yet).
shutil.rmtree(out_path, ignore_errors=True)
os.makedirs(out_path, exist_ok=True)
videos_list = os.listdir(os.path.join(root_path, "clips"))


def load_params(ckpt_file):
    """Load a checkpoint's state dict onto the target device, keys unchanged."""
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())
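

# Build the three sub-networks in eval mode and load the newest matching
# checkpoint for each; sorting the glob results makes [-1] the latest file.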
embed_net = GeneralEmbedModel(pretrained_model="swin-small", device=device).to(device).eval()
nonlocal_net = GeneralWarpNet(feature_channel=128).to(device).eval()
colornet = GeneralColorVidNet(7).to(device).eval()

embed_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "embed_net*.pth")))[-1]),
    strict=False,
)
nonlocal_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "nonlocal_net*.pth")))[-1])
)
colornet.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "colornet*.pth")))[-1])
)


def custom_transform(listTrans, img):
    """Apply a list of transforms, capturing the paddings from SquaredPadding."""
    padding = None  # stays None if no SquaredPadding is in the list
    for trans in listTrans:
        if isinstance(trans, SquaredPadding):
            img, padding = trans(img, return_paddings=True)
        else:
            img = trans(img)
    return img.to(device), padding
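

# Shared preprocessing: pad to a 224x224 square, convert RGB to Lab,
# convert to a tensor, then normalize.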
transformer = [
    SquaredPadding(target_size=224),
    RGB2Lab(),
    ToTensor(),
    Normalize(),
]
high_resolution = True
center_padder = CenterPad((224, 224))
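
# Main loop: embed each clip's reference frame once, then colorize the
# clip frame by frame, propagating every prediction to the next step.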
with torch.no_grad():
    for video_name in tqdm(videos_list):
        frames_list = sorted(os.listdir(os.path.join(root_path, "clips", video_name)))
        ref_dir = os.path.join(root_path, "ref", video_name)
        ref_path = os.path.join(ref_dir, os.listdir(ref_dir)[0])
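        # The first frame has no previous prediction, so start from zeros.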
        I_last_lab_predict = torch.zeros((1, 3, 224, 224)).to(device)
        video_out_path = os.path.join(out_path, video_name)
        os.makedirs(video_out_path, exist_ok=True)
        ref_frame_pil_rgb = Image.open(ref_path).convert("RGB")
        I_reference_lab, I_reference_padding = custom_transform(transformer, center_padder(ref_frame_pil_rgb))
        I_reference_lab = torch.unsqueeze(I_reference_lab, 0)
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)
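        # Reference features are computed once and reused for every frame.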
        for frame_name in frames_list:
            current_frame_pil_rgb = Image.open(os.path.join(root_path, "clips", video_name, frame_name)).convert("RGB")
            im_w, im_h = current_frame_pil_rgb.size
            I_current_lab, I_current_padding = custom_transform(transformer, current_frame_pil_rgb)
            I_current_lab = torch.unsqueeze(I_current_lab, 0)
            I_current_l = I_current_lab[:, 0:1, :, :]
            I_current_ab = I_current_lab[:, 1:3, :, :]
            I_current_ab_predict, _ = frame_colorization(
                IA_l=I_current_l,
                IB_lab=I_reference_lab,
                IA_last_lab=I_last_lab_predict,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=False,
                # temperature=1e-10,
            )
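            # High-resolution output: keep the full-size L channel from the
            # input frame and upsample only the predicted ab channels.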
            if high_resolution:
                high_lab = transforms.Compose([
                    SquaredPadding(target_size=max(im_h, im_w)),
                    RGB2Lab(),
                    ToTensor(),
                    Normalize(),
                ])
                high_lab_current = high_lab(current_frame_pil_rgb)
                high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
                high_l_current = high_lab_current[:, 0:1, :, :]
                # Upsample to an explicit size; a float scale_factor can end
                # up one pixel off after rounding.
                upsampler = torch.nn.Upsample(size=max(im_h, im_w), mode="bilinear", align_corners=False)
                high_ab_predict = upsampler(I_current_ab_predict)
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
            else:
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
            # TODO: strip the square padding (e.g. via UnpaddingSquare) before saving:
            # I_predict_rgb = unpadder(I_predict_rgb, I_current_padding)
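            # Convert Lab back to RGB, save the frame, and carry the current
            # prediction forward as the temporal prior for the next frame.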
            image_result2 = Image.fromarray(
                (I_predict_rgb[0].clamp(0, 1) * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
            )
            I_last_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
            image_result2.save(os.path.join(video_out_path, frame_name))