# Hugging Face Hub page header (uploader: ychenhq).
# Commit message: "Upload folder using huggingface_hub"
# Commit 04fbff5 (verified)
import os
import sys
import tqdm
import torch
import argparse
import numpy as np
import os.path as osp
from omegaconf import OmegaConf
sys.path.append('.')
from utils.build_utils import build_from_cfg
from metrics.psnr_ssim import calculate_psnr, calculate_ssim
from utils.utils import InputPadder, read, img2tensor
def parse_path(path: str) -> str:
    """Return the last three ``/``-separated components of *path*, re-joined.

    The input is split on the literal ``'/'`` character and the trailing
    three pieces are joined again with ``os.path.join`` (so the result uses
    the platform's separator). If *path* has fewer than three components,
    all of them are kept.

    Args:
        path: A ``/``-separated path string.

    Returns:
        The joined tail of the path.
    """
    path_list = path.split('/')
    # A negative slice never raises, so short paths are passed through whole.
    return osp.join(*path_list[-3:])
# Command-line interface: every option is optional and falls back to the
# default listed in the table below.
parser = argparse.ArgumentParser(prog='AMT', description='SNU-FILM evaluation')
for _short, _long, _default in (
    ('-c', '--config', 'cfgs/AMT-S.yaml'),
    ('-p', '--ckpt', 'pretrained/amt-s.pth'),
    ('-r', '--root', 'data/SNU_FILM'),
):
    parser.add_argument(_short, _long, default=_default)
args = parser.parse_args()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

# Build the network from the `network` section of the config file.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)

# Deserialize onto the CPU first: without `map_location`, a checkpoint that
# was saved from CUDA tensors cannot be loaded on a CPU-only machine, which
# would defeat the CPU fallback chosen above. The weights are moved to the
# target device together with the model right after.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()
# Fixed evaluation settings: `divisor` is handed to InputPadder and
# `scale_factor` is forwarded to the model on every call.
divisor = 20
scale_factor = 0.8
splits = ['easy', 'medium', 'hard', 'extreme']

# Constant 0.5 tensor passed as the model's third argument; it does not
# depend on the sample, so it is created once instead of once per triplet.
# NOTE(review): presumably the interpolation time step (midpoint) - confirm.
embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)

for split in splits:
    # Each non-empty line of the list file names three images; they are used
    # below as first input, ground-truth target, and second input.
    # Blank lines are skipped so they cannot cause an IndexError later on.
    with open(os.path.join(root, f'test-{split}.txt'), "r") as fr:
        file_list = [line.split() for line in fr if line.strip()]
    pbar = tqdm.tqdm(file_list, total=len(file_list))

    psnr_list = []
    ssim_list = []
    # Evaluation only: disable autograd so no graph is kept per sample.
    with torch.no_grad():
        for name in pbar:
            img0 = img2tensor(read(osp.join(root, parse_path(name[0])))).to(device)
            imgt = img2tensor(read(osp.join(root, parse_path(name[1])))).to(device)
            img1 = img2tensor(read(osp.join(root, parse_path(name[2])))).to(device)

            # Pad both inputs before the forward pass and undo the padding
            # on the prediction so it can be compared against `imgt`.
            padder = InputPadder(img0.shape, divisor)
            img0, img1 = padder.pad(img0, img1)

            imgt_pred = model(img0, img1, embt, scale_factor=scale_factor, eval=True)['imgt_pred']
            imgt_pred = padder.unpad(imgt_pred)

            psnr = calculate_psnr(imgt_pred, imgt).detach().cpu().numpy()
            ssim = calculate_ssim(imgt_pred, imgt).detach().cpu().numpy()

            psnr_list.append(psnr)
            ssim_list.append(ssim)

            # Show the running averages for this split on the progress bar.
            avg_psnr = np.mean(psnr_list)
            avg_ssim = np.mean(ssim_list)
            desc_str = f'[{network_name}/SNU-FILM] [{split}] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
            pbar.set_description_str(desc_str)