# -*- coding: utf-8 -*-
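"""Evaluate recurrent flow completion.

For each test video, bidirectional optical flow is estimated with RAFT (or
loaded from precomputed ground truth via --load_flow), the flow inside the
provided masks is completed by RecurrentFlowCompleteNet, and the per-video
end-point error (EPE) and per-frame runtime are reported for DAVIS or
YouTube-VOS.
"""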
import sys
sys.path.append(".")
import os
import argparse
import warnings
from time import time

import numpy as np
import torch
from torch.utils.data import DataLoader
import cvbase
import imageio

from core.dataset import TestDataset
from model.modules.flow_comp_raft import RAFT_bi
from model.recurrent_flow_completion import RecurrentFlowCompleteNet

warnings.filterwarnings("ignore")


def create_dir(dir_path):
    """Create a directory if it does not already exist."""
    os.makedirs(dir_path, exist_ok=True)


def save_flows(output, videoFlowF, videoFlowB):
    """Save an RGB visualization of each forward/backward flow map as a PNG."""
    # create_dir(os.path.join(output, 'forward_flo'))
    # create_dir(os.path.join(output, 'backward_flo'))
    create_dir(os.path.join(output, 'forward_png'))
    create_dir(os.path.join(output, 'backward_png'))
    N = videoFlowF.shape[-1]
    for i in range(N):
        forward_flow = videoFlowF[..., i]
        backward_flow = videoFlowB[..., i]
        # Map each 2-channel flow field to an RGB image for inspection.
        forward_flow_vis = cvbase.flow2rgb(forward_flow)
        backward_flow_vis = cvbase.flow2rgb(backward_flow)
        # cvbase.write_flow(forward_flow, os.path.join(output, 'forward_flo', '{:05d}.flo'.format(i)))
        # cvbase.write_flow(backward_flow, os.path.join(output, 'backward_flo', '{:05d}.flo'.format(i)))
        forward_flow_vis = (forward_flow_vis * 255.0).astype(np.uint8)
        backward_flow_vis = (backward_flow_vis * 255.0).astype(np.uint8)
        imageio.imwrite(os.path.join(output, 'forward_png', '{:05d}.png'.format(i)), forward_flow_vis)
        imageio.imwrite(os.path.join(output, 'backward_png', '{:05d}.png'.format(i)), backward_flow_vis)


def tensor2np(array):
    """Stack a list of T (1, 2, H, W) flow tensors into an (H, W, 2, T) numpy array."""
    array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
    return array


def main_worker(args):
    # set up datasets and data loader
    args.size = (args.width, args.height)
    test_dataset = TestDataset(vars(args))
    test_loader = DataLoader(test_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)
    # set up models
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    fix_raft = RAFT_bi(args.raft_model_path, device)

    fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
    # Inference only: freeze the flow-completion network and fix its norm stats.
    for p in fix_flow_complete.parameters():
        p.requires_grad = False
    fix_flow_complete.to(device)
    fix_flow_complete.eval()
    total_frame_epe = []
    time_all = []

    print('Start evaluation...')

    # create results directory
    result_path = os.path.join('results_flow', f'{args.dataset}')
    create_dir(result_path)
    eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"), "w")
    for index, items in enumerate(test_loader):
        frames, masks, flows_f, flows_b, video_name, frames_PIL = items
        local_masks = masks.float().to(device)
        video_length = frames.size(1)

        if args.load_flow:
            gt_flows_bi = (flows_f.to(device), flows_b.to(device))
        else:
            # Chunk long videos so bidirectional RAFT inference fits in GPU
            # memory; chunks after the first start one frame early so the flow
            # at every chunk boundary is still computed.
            short_len = 60
            if frames.size(1) > short_len:
                gt_flows_f_list, gt_flows_b_list = [], []
                for f in range(0, video_length, short_len):
                    end_f = min(video_length, f + short_len)
                    # Chunk-local names, so the ground-truth flows unpacked
                    # above are not clobbered before the EPE computation.
                    if f == 0:
                        flows_f_chunk, flows_b_chunk = fix_raft(frames[:, f:end_f], iters=args.raft_iter)
                    else:
                        flows_f_chunk, flows_b_chunk = fix_raft(frames[:, f-1:end_f], iters=args.raft_iter)
                    gt_flows_f_list.append(flows_f_chunk)
                    gt_flows_b_list.append(flows_b_chunk)
                gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
                gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
                gt_flows_bi = (gt_flows_f, gt_flows_b)
            else:
                gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
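        # Synchronize before starting the timer so pending CUDA kernels from
        # the flow estimation are not billed to the completion step.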
        torch.cuda.synchronize()
        time_start = time()
        # flow_length = flows_f.size(1)
        # f_stride = 30
        # pred_flows_f = []
        # pred_flows_b = []
        # suffix = flow_length % f_stride
        # last = flow_length // f_stride
        # for f in range(0, flow_length, f_stride):
        #     gt_flows_bi_i = (gt_flows_bi[0][:, f:f+f_stride], gt_flows_bi[1][:, f:f+f_stride])
        #     pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi_i, local_masks[:, f:f+f_stride+1])
        #     pred_flows_f_i, pred_flows_b_i = fix_flow_complete.combine_flow(gt_flows_bi_i, pred_flows_bi, local_masks[:, f:f+f_stride+1])
        #     pred_flows_f.append(pred_flows_f_i)
        #     pred_flows_b.append(pred_flows_b_i)
        # pred_flows_f = torch.cat(pred_flows_f, dim=1)
        # pred_flows_b = torch.cat(pred_flows_b, dim=1)
        # pred_flows_bi = (pred_flows_f, pred_flows_b)
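        # Two-stage completion: predict flow inside the masked regions, then
        # blend the predictions with the original flow outside the mask.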
        pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
        pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
        torch.cuda.synchronize()
        time_i = (time() - time_start) / frames.size(1)  # per-frame runtime
        time_all = time_all + [time_i] * frames.size(1)
        cur_video_epe = []
        # End-point error: mean L2 distance between ground-truth and predicted
        # flow vectors (the two flow channels sit at dim=2).
        epe1 = torch.mean(torch.sum((flows_f - pred_flows_bi[0].cpu())**2, dim=2).sqrt())
        epe2 = torch.mean(torch.sum((flows_b - pred_flows_bi[1].cpu())**2, dim=2).sqrt())
        cur_video_epe.append(epe1.numpy())
        cur_video_epe.append(epe2.numpy())

        total_frame_epe = total_frame_epe + [epe1.numpy()] * flows_f.size(1)
        total_frame_epe = total_frame_epe + [epe2.numpy()] * flows_f.size(1)
        cur_epe = sum(cur_video_epe) / len(cur_video_epe)
        avg_time = sum(time_all) / len(time_all)
        print(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
        )
        eval_summary.write(
            f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}\n'
        )
        # saving images for evaluating warping errors
        if args.save_results:
            # (T, B, 2, H, W): the list() below yields one flow map per frame.
            forward_flows = pred_flows_bi[0].cpu().permute(1, 0, 2, 3, 4)
            backward_flows = pred_flows_bi[1].cpu().permute(1, 0, 2, 3, 4)
            # forward_flows = flows_f.cpu().permute(1, 0, 2, 3, 4)
            # backward_flows = flows_b.cpu().permute(1, 0, 2, 3, 4)
            videoFlowF = tensor2np(list(forward_flows))
            videoFlowB = tensor2np(list(backward_flows))

            save_frame_path = os.path.join(result_path, video_name[0])
            save_flows(save_frame_path, videoFlowF, videoFlowB)
    avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe)

    print(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | Time: {avg_time:.4f}')
    eval_summary.write(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | Time: {avg_time:.4f}\n')
    eval_summary.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
    parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
    parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
    parser.add_argument('--video_root', default='dataset_root', type=str)
    parser.add_argument('--mask_root', default='mask_root', type=str)
    parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
    # argparse's type=bool treats any non-empty string (even 'False') as True,
    # so expose --load_flow as a plain on/off flag instead.
    parser.add_argument('--load_flow', action='store_true')
    parser.add_argument('--raft_iter', type=int, default=20)
    parser.add_argument('--save_results', action='store_true')
    parser.add_argument('--num_workers', default=4, type=int)
    args = parser.parse_args()
    main_worker(args)
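
# Example invocation (a sketch: the filename evaluate_flow_completion.py and
# the paths below are assumptions, adjust them to your checkout):
#   python evaluate_flow_completion.py --dataset davis \
#       --video_root /path/to/videos --mask_root /path/to/masks --save_results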