# SwinTExCo/video_predictor.py
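# Batch inference script: colorizes the grayscale frames of every clip under
# root_path/clips/ using one reference image per clip (root_path/ref/) and the
# three-network SwinTExCo pipeline, writing the predicted frames to out_path.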
import numpy as np
import shutil
import os
import torch
import glob
from tqdm import tqdm
from PIL import Image
from collections import OrderedDict
import torchvision.transforms as transforms
from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.vit.embed import GeneralEmbedModel
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.utils import (
    RGB2Lab,
    ToTensor,
    CenterPad,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
)
from src.models.CNN.FrameColor import frame_colorization
weight_path = "./ckp/12/"
out_path = "./output_video/"
root_path = "./EvalDataset"
device = "cuda" if torch.cuda.is_available() else "cpu"
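# Expected dataset layout: root_path/clips/<video>/<frames...> holds the input
# frames, and root_path/ref/<video>/ holds a single reference image per clip.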
# Reset the output directory; ignore_errors avoids a crash on the first run,
# when out_path does not exist yet.
shutil.rmtree(out_path, ignore_errors=True)
os.makedirs(out_path)
videos_list = sorted(os.listdir(os.path.join(root_path, "clips")))
def load_params(ckpt_file):
    # Load a checkpoint onto the target device and copy its state dict into an
    # OrderedDict (a convenient hook for key remapping, should the checkpoint
    # keys ever diverge from the module's parameter names).
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict((key, value) for key, value in params.items())
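# The pipeline uses three networks: embed_net extracts Swin Transformer
# features from the RGB reference, nonlocal_net warps the reference colors
# onto the current frame via non-local correspondence, and colornet predicts
# the final ab channels. All three run in eval mode for inference.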
embed_net = GeneralEmbedModel(pretrained_model="swin-small", device=device).to(device).eval()
nonlocal_net = GeneralWarpNet(feature_channel=128).to(device).eval()
colornet = GeneralColorVidNet(7).to(device).eval()
# Pick the lexicographically last checkpoint of each kind; glob returns files
# in arbitrary order, so sort before indexing.
embed_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "embed_net*.pth")))[-1]),
    strict=False,
)
nonlocal_net.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "nonlocal_net*.pth")))[-1])
)
colornet.load_state_dict(
    load_params(sorted(glob.glob(os.path.join(weight_path, "colornet*.pth")))[-1])
)
def custom_transform(listTrans, img):
    # Apply each transform in order; SquaredPadding additionally reports the
    # paddings it added so they can be undone later.
    padding = None
    for trans in listTrans:
        if isinstance(trans, SquaredPadding):
            img, padding = trans(img, return_paddings=True)
        else:
            img = trans(img)
    return img.to(device), padding
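# Frame preprocessing: pad to a 224x224 square, convert RGB to Lab, then
# normalize into the networks' expected value range.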
transformer = [
    SquaredPadding(target_size=224),
    RGB2Lab(),
    ToTensor(),
    Normalize(),
]
high_resolution = True
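# The reference image is center-padded/cropped to the 224x224 model input.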
center_padder = CenterPad((224, 224))
with torch.no_grad():
    for video_name in tqdm(videos_list):
        frames_list = sorted(os.listdir(os.path.join(root_path, "clips", video_name)))
        # Each clip has a single reference image in its ref/ directory.
        ref_dir = os.path.join(root_path, "ref", video_name)
        ref_path = os.path.join(ref_dir, os.listdir(ref_dir)[0])
        # Start from a zero Lab tensor: the first frame has no previous
        # prediction to condition on.
        I_last_lab_predict = torch.zeros((1, 3, 224, 224)).to(device)
        video_out_path = os.path.join(out_path, video_name)
        os.makedirs(video_out_path)
        ref_frame_pil_rgb = Image.open(ref_path).convert("RGB")
        I_reference_lab, I_reference_padding = custom_transform(transformer, center_padder(ref_frame_pil_rgb))
        I_reference_lab = torch.unsqueeze(I_reference_lab, 0)
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)
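        # Reference features are computed once per clip and reused for every
        # frame, so only the per-frame networks run inside the loop below.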
for frame_name in frames_list:
            # NOTE: an earlier version routed frames through the Predictor
            # class, but that path produced incorrect results; the networks
            # are called directly below instead.
            current_frame_pil_rgb = Image.open(os.path.join(root_path, "clips", video_name, frame_name)).convert("RGB")
            im_w, im_h = current_frame_pil_rgb.size
            I_current_lab, I_current_padding = custom_transform(transformer, current_frame_pil_rgb)
            I_current_lab = torch.unsqueeze(I_current_lab, 0)
            I_current_l = I_current_lab[:, 0:1, :, :]
            I_current_ab = I_current_lab[:, 1:3, :, :]
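            # frame_colorization warps the reference colors onto the current
            # frame's luminance, guided by features_B and the previous
            # prediction, and returns the predicted ab channels at 224x224.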
            # The outer torch.no_grad() already covers this call, so no inner
            # no_grad block is needed.
            I_current_ab_predict, _ = frame_colorization(
                IA_l=I_current_l,
                IB_lab=I_reference_lab,
                IA_last_lab=I_last_lab_predict,
                features_B=features_B,
                embed_net=embed_net,
                colornet=colornet,
                nonlocal_net=nonlocal_net,
                luminance_noise=False,
            )
if high_resolution:
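                # Rebuild the Lab transform at the frame's full (padded)
                # resolution to recover a high-resolution luminance channel.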
                high_lab = transforms.Compose([
                    SquaredPadding(target_size=max(im_h, im_w)),
                    RGB2Lab(),
                    ToTensor(),
                    Normalize(),
                ])
                high_lab_current = high_lab(current_frame_pil_rgb)
                high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
                high_l_current = high_lab_current[:, 0:1, :, :]
                high_ab_current = high_lab_current[:, 1:3, :, :]
                # Upsample the 224x224 ab prediction to the padded high-res
                # size; an explicit target size avoids the off-by-one shape
                # mismatches that a fractional scale_factor can produce.
                high_ab_predict = torch.nn.functional.interpolate(
                    I_current_ab_predict,
                    size=(max(im_h, im_w), max(im_h, im_w)),
                    mode="bilinear",
                    align_corners=False,
                )
I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
else:
I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_current_l), I_current_ab_predict), dim=1))
            # NOTE: the saved frames keep the square padding added by
            # SquaredPadding; src.utils.UnpaddingSquare with I_current_padding
            # could be applied here to crop back to the original aspect ratio.
            image_result2 = Image.fromarray((I_predict_rgb[0].clamp(0, 1) * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8))
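            # Propagate this frame's prediction as the temporal reference for
            # the next frame.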
I_last_lab_predict = torch.cat((I_current_l, I_current_ab_predict), dim=1)
            image_result2.save(os.path.join(video_out_path, frame_name))
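# A minimal sketch for stitching each clip's predicted frames into a video
# with ffmpeg (an assumption, not part of this pipeline: it presumes ffmpeg is
# installed, frames are .jpg files whose names sort in playback order, and an
# illustrative 24 fps frame rate):
#
#   import subprocess
#   for video_name in sorted(os.listdir(out_path)):
#       subprocess.run([
#           "ffmpeg", "-y", "-framerate", "24",
#           "-pattern_type", "glob", "-i", os.path.join(out_path, video_name, "*.jpg"),
#           "-c:v", "libx264", "-pix_fmt", "yuv420p",
#           os.path.join(out_path, video_name + ".mp4"),
#       ], check=True)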