Spaces:
Running
on
A10G
Running
on
A10G
| # -*- coding: utf-8 -*- | |
| import sys | |
| sys.path.append(".") | |
| import os | |
| import cv2 | |
| import argparse | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from RAFT import RAFT | |
| from utils.flow_util import * | |
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write an image to disk through OpenCV.

    Args:
        img: image object handed unchanged to ``cv2.imwrite``.
        file_path: destination path of the image file.
        params: extra encoder parameters forwarded to ``cv2.imwrite``.
        auto_mkdir: when truthy, the parent directory of ``file_path`` is
            created first (no error if it already exists).

    Returns:
        The value returned by ``cv2.imwrite``.
    """
    if auto_mkdir:
        # Resolve to an absolute path so a bare file name (empty dirname)
        # maps to the current working directory.
        parent_dir = os.path.dirname(file_path)
        os.makedirs(os.path.abspath(parent_dir), exist_ok=True)

    write_result = cv2.imwrite(file_path, img, params)
    return write_result
def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
    """Initializes the RAFT model.

    Args:
        model_path: path of the checkpoint file read with ``torch.load``.
        device: device the model is moved to after the weights are loaded.

    Returns:
        The RAFT module (unwrapped from ``DataParallel``), placed on
        ``device`` and switched to eval mode.
    """
    # The parser object is used only as an attribute container that is
    # handed to RAFT(args); no command-line parsing happens here.
    args = argparse.ArgumentParser()
    args.raft_model = model_path
    args.small = False
    args.mixed_precision = False
    args.alternate_corr = False

    # The model is wrapped in DataParallel only for loading and unwrapped
    # right after -- presumably the checkpoint keys carry the "module."
    # prefix; confirm against the weights file.
    model = torch.nn.DataParallel(RAFT(args))

    # Deserialize onto the CPU first so loading does not depend on the device
    # the checkpoint was saved from (e.g. CUDA-saved weights on a CPU-only
    # host); the model is moved to the requested device below.
    state_dict = torch.load(args.raft_model, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.module

    model.to(device)
    model.eval()
    return model
def _load_frame(img_path, size, device):
    """Read one frame and return it as a normalized RGB tensor on ``device``.

    Args:
        img_path: path of the image file, decoded with ``cv2.imread``.
        size: ``(height, width)`` the frame is bilinearly resized to.
        device: torch device the tensor is moved to.

    Returns:
        Float tensor with a leading batch dimension of 1, resized to ``size``
        and rescaled from [0, 1] to [-1, 1].

    Raises:
        IOError: if OpenCV cannot read or decode ``img_path``.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError(f'Failed to read image: {img_path}')

    frame = transforms.ToTensor()(Image.fromarray(bgr)).unsqueeze(0).to(device)
    # OpenCV decodes in BGR channel order; reorder the channels to RGB.
    frame = frame[:, [2, 1, 0], :, :]
    frame = F.interpolate(input=frame,
                          size=size,
                          mode='bilinear',
                          align_corners=False)
    return frame * 2 - 1


if __name__ == '__main__':
    # Script entry point: for every sub-folder of --root_path, compute the
    # forward and backward RAFT flow between each pair of consecutive frames
    # and write them as .flo files under --save_path/<sub-folder>/.
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages')
    parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo')
    parser.add_argument('--height', type=int, default=240)
    parser.add_argument('--width', type=int, default=432)
    args = parser.parse_args()

    # Flow model
    RAFT_model = initialize_RAFT(device=device)

    root_path = args.root_path
    save_path = args.save_path
    # NOTE(review): an earlier, disabled variant of this script rounded the
    # frame size up to a multiple of 16 instead of using a fixed size --
    # presumably the model has a size constraint; confirm before passing
    # arbitrary --height/--width values.
    h_new, w_new = (args.height, args.width)

    file_list = sorted(os.listdir(root_path))
    for f in file_list:
        video_dir = os.path.join(root_path, f)
        if not os.path.isdir(video_dir):
            # Stray files in the dataset root are not frame folders.
            continue

        print(f'Processing: {f} ...')
        m_list = sorted(os.listdir(video_dir))
        if len(m_list) < 2:
            # Flow needs at least one pair of frames.
            continue

        # flowwrite comes from utils.flow_util and is not visible here; create
        # the output folder up front so writing does not depend on it doing so.
        out_dir = os.path.join(save_path, f)
        os.makedirs(out_dir, exist_ok=True)

        # Each frame is loaded and preprocessed once and reused as the first
        # frame of the next pair. The tensors are only ever passed to the
        # model, which already received each of them twice per pair.
        prev_name = m_list[0]
        prev_img = _load_frame(os.path.join(video_dir, prev_name), (h_new, w_new), device)

        for cur_name in m_list[1:]:
            cur_img = _load_frame(os.path.join(video_dir, cur_name), (h_new, w_new), device)

            with torch.no_grad():
                _, flow_f = RAFT_model(prev_img, cur_img, iters=20, test_mode=True)
                _, flow_b = RAFT_model(cur_img, prev_img, iters=20, test_mode=True)

            # Drop the batch dimension and move channels last before saving.
            flow_f = flow_f[0].permute(1, 2, 0).cpu().numpy()
            flow_b = flow_b[0].permute(1, 2, 0).cpu().numpy()

            # splitext instead of slicing off four characters, so names with
            # extensions of other lengths (e.g. ".jpeg") keep a correct stem.
            prev_stem = os.path.splitext(prev_name)[0]
            cur_stem = os.path.splitext(cur_name)[0]
            save_flow_f = os.path.join(out_dir, f'{prev_stem}_{cur_stem}_f.flo')
            save_flow_b = os.path.join(out_dir, f'{cur_stem}_{prev_stem}_b.flo')

            flowwrite(flow_f, save_flow_f, quantize=False)
            flowwrite(flow_b, save_flow_b, quantize=False)

            prev_name, prev_img = cur_name, cur_img