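"""Reference-based video colorization over an evaluation set.

For every video folder under EvalDataset/clips/, an image from the matching
EvalDataset/ref/ folder serves as the color reference. Each grayscale frame is
colorized with frame_colorization, optionally re-merged with the frame's
full-resolution luminance, and written to ./output_video/<video>/<frame>.
(Docstring added for clarity; paths and layout are inferred from the code below.)
"""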
import glob
import os
import shutil
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.utils import (
    CenterPad,
    Normalize,
    RGB2Lab,
    SquaredPadding,
    ToTensor,
    tensor_lab2rgb,
    uncenter_l,
)
weight_path = "./ckp/12/"
out_path = "./output_video/"
root_path = "./EvalDataset"
device = "cuda"

# Recreate the output directory from scratch on every run; a bare rmtree
# would crash on the first run when the directory does not exist yet.
if os.path.exists(out_path):
    shutil.rmtree(out_path)
os.makedirs(out_path)

videos_list = os.listdir(root_path + "/clips/")
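# Expected dataset layout (inferred from the paths used below):
#   EvalDataset/clips/<video>/<frame>  - grayscale input frames
#   EvalDataset/ref/<video>/<image>    - one color reference image per video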
def load_params(ckpt_file):
    # map_location keeps loading robust when the checkpoint was saved on a
    # different device than the one used for inference.
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())
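# Build the three sub-networks (feature embedding, non-local warping, and
# colorization) and switch them to eval mode for inference.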
embed_net = GeneralEmbedModel(pretrained_model="swin-small", device=device).to(device).eval()
nonlocal_net = GeneralWarpNet(feature_channel=128).to(device).eval()
colornet = GeneralColorVidNet(7).to(device).eval()

# glob returns files in arbitrary order, so sort before taking the latest checkpoint.
embed_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "embed_net*.pth")))[-1]),
    strict=False,
)
nonlocal_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "nonlocal_net*.pth")))[-1])
)
colornet.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "colornet*.pth")))[-1])
)
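# Apply a list of transforms in order, capturing the paddings that
# SquaredPadding returns so they could be undone later.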
def custom_transform(listTrans, img):
    padding = None  # stays None if no SquaredPadding is in the list
    for trans in listTrans:
        if isinstance(trans, SquaredPadding):
            img, padding = trans(img, return_paddings=True)
        else:
            img = trans(img)
    return img.to(device), padding
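# Transform chain used for both the reference and the current frames at the
# model's 224x224 working resolution.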
transformer = [
    SquaredPadding(target_size=224),
    RGB2Lab(),
    ToTensor(),
    Normalize(),
]
high_resolution = True
center_padder = CenterPad((224, 224))
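# Main inference loop: everything below runs with autograd disabled.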
with torch.no_grad():
    for video_name in tqdm(videos_list):
        frames_list = sorted(os.listdir(root_path + "/clips/" + video_name))

        # Use the first image in the video's ref/ folder as the color reference.
        ref_dir = root_path + "/ref/" + video_name + "/"
        ref_path = ref_dir + sorted(os.listdir(ref_dir))[0]

        # The "previous frame" fed to the first frame is all zeros.
        I_last_lab_predict = torch.zeros((1, 3, 224, 224)).to(device)

        video_out_path = out_path + "/" + video_name + "/"
        os.makedirs(video_out_path, exist_ok=True)

        ref_frame_pil_rgb = Image.open(ref_path).convert("RGB")
        I_reference_lab, I_reference_padding = custom_transform(transformer, center_padder(ref_frame_pil_rgb))
        I_reference_lab = torch.unsqueeze(I_reference_lab, 0)
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)
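        # features_B is computed once per video from the reference image and
        # reused unchanged for every frame below.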
        for frame_name in frames_list:
            current_frame_pil_rgb = Image.open(root_path + "/clips/" + video_name + "/" + frame_name).convert("RGB")
            im_w, im_h = current_frame_pil_rgb.size
            I_current_lab, I_current_padding = custom_transform(transformer, current_frame_pil_rgb)
            I_current_lab = torch.unsqueeze(I_current_lab, 0)
            I_current_l = I_current_lab[:, 0:1, :, :]
            I_current_ab = I_current_lab[:, 1:3, :, :]
            # Predict the ab channels for the current frame (autograd is
            # already disabled by the outer no_grad context).
            I_current_ab_predict, _ = frame_colorization(
                IA_l=I_current_l,
                IB_lab=I_reference_lab,
                IA_last_lab=I_last_lab_predict,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=False,
                # temperature=1e-10,
            )
            if high_resolution:
                # Keep the full-resolution luminance: redo the Lab transform at
                # the frame's squared size and only upsample the predicted ab.
                high_lab = transforms.Compose([
                    SquaredPadding(target_size=max(im_h, im_w)),
                    RGB2Lab(),
                    ToTensor(),
                    Normalize(),
                ])
                high_lab_current = high_lab(current_frame_pil_rgb)
                high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
                high_l_current = high_lab_current[:, 0:1, :, :]
                # An explicit output size avoids the off-by-one shapes that a
                # float scale_factor (max(im_h, im_w) / 224) can produce.
                upsampler = torch.nn.Upsample(size=(max(im_h, im_w), max(im_h, im_w)), mode="bilinear")
                high_ab_predict = upsampler(I_current_ab_predict)
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
            else:
                I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))

            # Convert to an 8-bit RGB image and save it under the frame's name.
            image_result = Image.fromarray((I_predict_rgb[0] * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8))
            # This frame's 224x224 prediction becomes the next frame's "last"
            # input, propagating color temporally through the clip.
            I_last_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
            image_result.save(video_out_path + "/" + frame_name)