Spaces:
Runtime error
Runtime error
import argparse
import glob
import os
import os.path as osp
import subprocess
import sys

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf

# Appends the current working directory to the import path so the project-local
# packages below (utils, metrics) resolve when the script is launched from there.
sys.path.append('.')
from utils.utils import InputPadder, read, img2tensor
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim
# Command-line interface. Every option has a default, so the script also runs
# with no arguments at all.
parser = argparse.ArgumentParser(
    prog='AMT',
    description='Xiph evaluation',
)
parser.add_argument(
    '-c', '--config', default='cfgs/AMT-S.yaml',
    help="model config YAML; its 'network' section is used to build the model "
         "(default: %(default)s)",
)
parser.add_argument(
    '-p', '--ckpt', default='pretrained/amt-s.pth',
    help="checkpoint file holding the weights under a 'state_dict' key "
         "(default: %(default)s)",
)
parser.add_argument(
    '-r', '--root', default='data/xiph',
    help='directory the Xiph sequences are extracted into as PNG frames '
         '(default: %(default)s)',
)
args = parser.parse_args()
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Expose the parsed CLI options under the module-level names the rest of the
# script reads.
cfg_path, ckpt_path, root = args.config, args.ckpt, args.root
# Build the network described by the `network` section of the config file.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)

# map_location='cpu' lets a checkpoint whose tensors were saved from a GPU load
# on CPU-only machines as well; the weights are moved to `device` below.
ckpt = torch.load(ckpt_path, map_location='cpu')
# strict=False tolerates missing/unexpected keys instead of raising.
# NOTE(review): confirm that partial loading is intended for these checkpoints.
model.load_state_dict(ckpt['state_dict'], strict=False)
model = model.to(device)
model.eval()  # inference mode for layers such as dropout / batch-norm
############################################# Prepare Dataset #############################################
# Xiph 4096x2160 test sequences as (site folder, sequence name) pairs. The
# download URLs and the per-sequence directory names are both derived from this
# one table, so the two lists always stay aligned index-for-index.
_XIPH_SEQUENCES = [
    ('ElFuente', 'BoxingPractice'),
    ('ElFuente', 'Crosswalk'),
    ('Chimera', 'DrivingPOV'),
    ('ElFuente', 'FoodMarket'),
    ('ElFuente', 'FoodMarket2'),
    ('ElFuente', 'RitualDance'),
    ('ElFuente', 'SquareAndTimelapse'),
    ('ElFuente', 'Tango'),
]
download_links = [
    f'https://media.xiph.org/video/derf/{folder}/Netflix_{name}_4096x2160_60fps_10bit_420.y4m'
    for folder, name in _XIPH_SEQUENCES
]
file_list = [name for _, name in _XIPH_SEQUENCES]
# Extract the first 100 frames of every sequence into <root>/<name>/ as RGB
# PNGs named 001.png, 002.png, ... (the evaluation below reads 001-099).
for file_name, link in zip(file_list, download_links):
    data_dir = osp.join(root, file_name)
    os.makedirs(data_dir, exist_ok=True)
    # `root` comes from the command line and may contain glob metacharacters.
    if len(glob.glob(osp.join(glob.escape(data_dir), '*.png'))) < 100:
        # Use an argument list (no shell) so paths with spaces or quotes pass
        # through intact. -y overwrites frames left by an interrupted run
        # instead of blocking on ffmpeg's interactive prompt. check=True stops
        # here on failure instead of letting the evaluation die later on a
        # missing frame.
        subprocess.run(
            ['ffmpeg', '-y', '-i', link, '-pix_fmt', 'rgb24', '-vframes', '100',
             osp.join(data_dir, '%03d.png')],
            check=True,
        )
############################################### Prepare End ###############################################
# Passed to InputPadder; presumably the frame height/width are padded up to a
# multiple of this. NOTE(review): confirm against InputPadder.
divisor = 32
# Forwarded unchanged to the model's forward() as `scale_factor`.
scale_factor = 0.5


def _to_eval_size(img, category):
    """Return `img` resized or cropped for the given Xiph benchmark category.

    'resized-2k' area-downsamples the frame to 2048x1080 (cv2 dsize is
    width x height). 'cropped-4k' trims 540 rows and 1024 columns from every
    side, keeping the centre. Assumes `img` is an (H, W, C) array as returned
    by `read`; for the 4096x2160 sources the crop is then 1080x2048.

    Raises:
        ValueError: if `category` is not one of the two names above.
    """
    if category == 'resized-2k':
        return cv2.resize(src=img, dsize=(2048, 1080), fx=0.0, fy=0.0,
                          interpolation=cv2.INTER_AREA)
    if category == 'cropped-4k':
        return img[540:-540, 1024:-1024, :]
    raise ValueError(f'unknown Xiph category: {category!r}')


# The target is always the frame midway (by index) between its two inputs, so
# the `embt` value of 1/2 is the same for every sample; build it once.
embt = torch.tensor(1 / 2).float().view(1, 1, 1, 1).to(device)

for category in ['resized-2k', 'cropped-4k']:
    psnr_list = []
    ssim_list = []
    pbar = tqdm.tqdm(file_list, total=len(file_list))
    for file_name in pbar:
        dir_name = osp.join(root, file_name)
        # Even frames 2..98 are the targets; odd frames on either side are inputs.
        for frame_idx in range(2, 99, 2):
            img0 = _to_eval_size(read(f'{dir_name}/{frame_idx - 1:03d}.png'), category)
            img1 = _to_eval_size(read(f'{dir_name}/{frame_idx + 1:03d}.png'), category)
            imgt = _to_eval_size(read(f'{dir_name}/{frame_idx:03d}.png'), category)
            img0 = img2tensor(img0).to(device)
            imgt = img2tensor(imgt).to(device)
            img1 = img2tensor(img1).to(device)

            padder = InputPadder(img0.shape, divisor)
            img0, img1 = padder.pad(img0, img1)
            with torch.no_grad():
                imgt_pred = model(img0, img1, embt, scale_factor=scale_factor,
                                  eval=True)['imgt_pred']
                imgt_pred = padder.unpad(imgt_pred)

            psnr = calculate_psnr(imgt_pred, imgt)
            ssim = calculate_ssim(imgt_pred, imgt)
            # Record the current frame first so the running mean includes it
            # and is never taken over an empty list (which yields nan).
            psnr_list.append(psnr)
            ssim_list.append(ssim)
            avg_psnr = np.mean(psnr_list)
            avg_ssim = np.mean(ssim_list)
            desc_str = (f'[{network_name}/Xiph] [{category}/{file_name}] '
                        f'psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}')
            pbar.set_description_str(desc_str)