import glob
import os
import shutil
from collections import OrderedDict

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from src.models.CNN.ColorVidNet import GeneralColorVidNet
from src.models.CNN.FrameColor import frame_colorization
from src.models.CNN.NonlocalNet import GeneralWarpNet
from src.models.vit.embed import GeneralEmbedModel
from src.utils import (
    RGB2Lab,
    ToTensor,
    Normalize,
    uncenter_l,
    tensor_lab2rgb,
    SquaredPadding,
    UnpaddingSquare
)


def load_params(ckpt_file):
    """Load a checkpoint file and return its weights as an OrderedDict."""
    params = torch.load(ckpt_file, map_location=device)
    return OrderedDict(params.items())


def custom_transform(transform_list, img):
    """Apply a list of transforms to img, capturing the paddings produced
    by SquaredPadding so the output can be unpadded later."""
    padding = None  # stays None if no SquaredPadding is in the list
    for transform in transform_list:
        if isinstance(transform, SquaredPadding):
            img, padding = transform(img, return_paddings=True)
        else:
            img = transform(img)
    return img.to(device), padding


def save_frames(predicted_rgb, video_name, frame_name):
    """Save a predicted RGB frame (C, H, W, float values in [0, 255]) as a JPEG."""
    if predicted_rgb is not None:
        predicted_rgb = np.clip(predicted_rgb, 0, 255).astype(np.uint8)
        predicted_rgb = np.transpose(predicted_rgb, (1, 2, 0))  # CHW -> HWC
        pil_img = Image.fromarray(predicted_rgb)
        pil_img.save(os.path.join(OUTPUT_RESULT_PATH, video_name, frame_name))


def extract_frames_from_video(video_path):
    """Split a video into JPEG frames; return the frame paths and the FPS."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Recreate the output folder for this video's frames
    output_frames_path = os.path.join(INPUT_VIDEO_FRAMES_PATH, os.path.basename(video_path))
    if os.path.exists(output_frames_path):
        shutil.rmtree(output_frames_path)
    os.makedirs(output_frames_path)

    currentframe = 0
    frame_path_list = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Zero-padded names keep the frames sorted lexicographically
        name = os.path.join(output_frames_path, f'{currentframe:09d}.jpg')
        frame_path_list.append(name)
        cv2.imwrite(name, frame)
        currentframe += 1

    cap.release()
    return frame_path_list, fps


def combine_frames_from_folder(frames_list_path, fps=30):
    """Reassemble the JPEG frames in a folder into an MP4 video."""
    frames_list = glob.glob(f'{frames_list_path}/*.jpg')
    frames_list.sort()

    sample_shape = cv2.imread(frames_list[0]).shape
    output_video_path = os.path.join(frames_list_path, 'output_video.mp4')
    out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'),
                          fps, (sample_shape[1], sample_shape[0]))
    for filename in frames_list:
        out.write(cv2.imread(filename))
    out.release()
    return output_video_path


def upscale_image(I_current_rgb, I_current_ab_predict):
    """Upsample the predicted ab channels to the original resolution and
    recombine them with the full-resolution L channel."""
    W, H = I_current_rgb.size  # PIL .size is (width, height)
    high_lab_transforms = [
        SquaredPadding(target_size=max(H, W)),
        RGB2Lab(),
        ToTensor(),
        Normalize()
    ]
    high_lab_current, paddings = custom_transform(high_lab_transforms, I_current_rgb)
    high_lab_current = torch.unsqueeze(high_lab_current, dim=0).to(device)
    high_l_current = high_lab_current[:, 0:1, :, :]
    # An explicit output size guarantees the upsampled ab channels match the
    # padded L channel exactly (a float scale_factor can be off by one pixel)
    upsampler = torch.nn.Upsample(size=(max(H, W), max(H, W)), mode="bilinear", align_corners=False)
    high_ab_predict = upsampler(I_current_ab_predict)
    I_predict_rgb = tensor_lab2rgb(torch.cat((uncenter_l(high_l_current), high_ab_predict), dim=1))
    unpadder = UnpaddingSquare()
    return unpadder(I_predict_rgb, paddings)


def colorize_video(video_path, ref_np):
    """Colorize every frame of a video from a color reference image, then
    reassemble the colorized frames into a video."""
    frames_list, fps = extract_frames_from_video(video_path)

    # Encode the reference image once; its features guide every frame
    frame_ref = Image.fromarray(ref_np).convert("RGB")
    I_last_lab_predict = None
    IB_lab, _ = custom_transform(transforms, frame_ref)
    IB_lab = IB_lab.unsqueeze(0).to(device)
    with torch.no_grad():
        I_reference_lab = IB_lab
        I_reference_l = I_reference_lab[:, 0:1, :, :]
        I_reference_ab = I_reference_lab[:, 1:3, :, :]
        I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1)).to(device)
        features_B = embed_net(I_reference_rgb)

    # Recreate the output folder for this video
    video_name = frames_list[0].split(os.sep)[-2]
    if os.path.exists(os.path.join(OUTPUT_RESULT_PATH, video_name)):
        shutil.rmtree(os.path.join(OUTPUT_RESULT_PATH, video_name))
    os.makedirs(os.path.join(OUTPUT_RESULT_PATH, video_name), exist_ok=True)

    for frame_path in tqdm(frames_list):
        curr_frame = Image.open(frame_path).convert("RGB")
        IA_lab, _ = custom_transform(transforms, curr_frame)
        IA_lab = IA_lab.unsqueeze(0).to(device)
        IA_l = IA_lab[:, 0:1, :, :]

        # The first frame has no previous prediction to propagate
        if I_last_lab_predict is None:
            I_last_lab_predict = torch.zeros_like(IA_lab).to(device)

        with torch.no_grad():
            I_current_ab_predict, _ = frame_colorization(
                IA_l,
                I_reference_lab,
                I_last_lab_predict,
                features_B,
                embed_net,
                nonlocal_net,
                colornet,
                luminance_noise=0,
                temperature=1e-10,
                joint_training=False
            )
            I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)

        IA_predict_rgb = upscale_image(curr_frame, I_current_ab_predict)
        save_frames(IA_predict_rgb.squeeze(0).cpu().numpy() * 255,
                    video_name, os.path.basename(frame_path))

    return combine_frames_from_folder(os.path.join(OUTPUT_RESULT_PATH, video_name), fps)


if __name__ == '__main__':
    # Globals used by the functions above
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    INPUT_VIDEO_FRAMES_PATH = 'inputs'
    OUTPUT_RESULT_PATH = 'outputs'
    weight_path = 'checkpoints'

    embed_net = GeneralEmbedModel(pretrained_model="swin-tiny", device=device).to(device)
    nonlocal_net = GeneralWarpNet(feature_channel=128).to(device)
    colornet = GeneralColorVidNet(7).to(device)
    embed_net.eval()
    nonlocal_net.eval()
    colornet.eval()

    # Load weights; embed_net is built from pretrained Swin weights, so only
    # the warp and colorization networks are restored from checkpoints
    nonlocal_net_params = load_params(os.path.join(weight_path, "nonlocal_net.pth"))
    colornet_params = load_params(os.path.join(weight_path, "colornet.pth"))
    nonlocal_net.load_state_dict(nonlocal_net_params, strict=True)
    colornet.load_state_dict(colornet_params, strict=True)

    # Shared preprocessing: pad to a 224x224 square, convert to Lab,
    # convert to tensor, normalize
    transforms = [SquaredPadding(target_size=224), RGB2Lab(), ToTensor(), Normalize()]

    demo = gr.Interface(colorize_video,
                        inputs=[gr.Video(), gr.Image()],
                        outputs="playable_video")
    demo.launch()