Spaces:
Runtime error
Runtime error
import sys | |
import tqdm | |
import torch | |
import argparse | |
import numpy as np | |
from omegaconf import OmegaConf | |
sys.path.append('.') | |
from utils.build_utils import build_from_cfg | |
from datasets.gopro_datasets import GoPro_Test_Dataset | |
from metrics.psnr_ssim import calculate_psnr, calculate_ssim | |
# Command-line interface: each option has a short flag, a long flag and a
# default path, so the script can be run without any arguments.
parser = argparse.ArgumentParser(prog='AMT', description='GOPRO evaluation')
for short_flag, long_flag, fallback in (
    ('-c', '--config', 'cfgs/AMT-S_gopro.yaml'),    # network config (YAML)
    ('-p', '--ckpt', 'pretrained/gopro_amt-s.pth'),  # checkpoint file
    ('-r', '--root', 'data/GOPRO'),                  # dataset root directory
):
    parser.add_argument(short_flag, long_flag, default=fallback)
args = parser.parse_args()
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg_path = args.config
ckpt_path = args.ckpt
root = args.root

# Only the `network` section of the config is used: it names the network and
# is handed to build_from_cfg to construct it.
network_cfg = OmegaConf.load(cfg_path).network
network_name = network_cfg.name
model = build_from_cfg(network_cfg)

# Deserialize the checkpoint onto the CPU. Without map_location, tensors that
# were saved from a CUDA device cannot be restored on a CPU-only host, which
# would defeat the CPU fallback chosen above. The model is still on the CPU at
# this point, so the weights are copied in first and moved to `device` after.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
model = model.to(device)
model.eval()
# Evaluate the model on every sample of the GoPro test set, keeping the
# per-sample PSNR/SSIM values and showing their running mean in the
# progress-bar description.
dataset = GoPro_Test_Dataset(dataset_dir=root)
psnr_list, ssim_list = [], []
pbar = tqdm.tqdm(dataset, total=len(dataset))
for sample in pbar:
    # Move each entry of the sample to the target device and prepend a
    # dimension of size 1 (presumably the batch axis -- confirm against the
    # model's forward signature).
    input_dict = {
        name: tensor.to(device).unsqueeze(0) for name, tensor in sample.items()
    }

    # Inference only: gradients are never needed here.
    with torch.no_grad():
        imgt_pred = model(**input_dict)['imgt_pred']
        target = input_dict['imgt']
        psnr_list.append(calculate_psnr(imgt_pred, target))
        ssim_list.append(calculate_ssim(imgt_pred, target))

    # Refresh the running averages over all samples seen so far.
    avg_psnr, avg_ssim = np.mean(psnr_list), np.mean(ssim_list)
    pbar.set_description_str(
        f'[{network_name}/GOPRO] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}'
    )