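"""Training and test datasets for flow-guided video inpainting.

On-disk layout implied by the path handling below (names and extensions
are illustrative, not confirmed by the source):

    datasets/<name>/train.json                    # index of training videos
    <video_root>/<video_name>/<frame>.jpg         # RGB frames, sorted by name
    <flow_root>/<video_name>/<cur>_<next>_f.flo   # precomputed forward flow
    <flow_root>/<video_name>/<next>_<cur>_b.flo   # precomputed backward flow
    <mask_root>/<video_name>/00000.png            # test masks (non-zero = hole)
"""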
import os
import json
import random

import cv2
import numpy as np
from PIL import Image

import torch
import torchvision.transforms as transforms

from utils.file_client import FileClient
from utils.img_util import imfrombytes
from utils.flow_util import resize_flow, flowread
from core.utils import (create_random_shape_with_random_motion, Stack,
                        ToTorchFormatTensor, GroupRandomHorizontalFlip,
                        GroupRandomHorizontalFlowFlip)


class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, args: dict):
        self.args = args
        self.video_root = args['video_root']
        self.flow_root = args['flow_root']
        self.num_local_frames = args['num_local_frames']
        self.num_ref_frames = args['num_ref_frames']
        self.size = self.w, self.h = (args['w'], args['h'])

        self.load_flow = args['load_flow']
        if self.load_flow:
            assert os.path.exists(self.flow_root)

        json_path = os.path.join('./datasets', args['name'], 'train.json')
        with open(json_path, 'r') as f:
            self.video_train_dict = json.load(f)
        self.video_names = sorted(list(self.video_train_dict.keys()))
        # self.video_names = sorted(os.listdir(self.video_root))

        # Index frames per video; keep only videos long enough to supply
        # both the local window and the reference frames.
        self.video_dict = {}
        self.frame_dict = {}
        for v in self.video_names:
            frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
            v_len = len(frame_list)
            if v_len > self.num_local_frames + self.num_ref_frames:
                self.video_dict[v] = v_len
                self.frame_dict[v] = frame_list
        self.video_names = list(self.video_dict.keys())  # update names

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def _sample_index(self, length, sample_length, num_ref_frame=3):
        # Pick a random contiguous window of `sample_length` local frames,
        # then `num_ref_frame` reference frames from outside that window.
        complete_idx_set = list(range(length))
        pivot = random.randint(0, length - sample_length)
        local_idx = complete_idx_set[pivot:pivot + sample_length]
        remain_idx = list(set(complete_idx_set) - set(local_idx))
        ref_index = sorted(random.sample(remain_idx, num_ref_frame))
        return local_idx + ref_index
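
    # Illustrative example (arbitrary values): _sample_index(50, 10, 3) might
    # return [17, 18, ..., 26] + [3, 31, 44] -- a contiguous local window
    # followed by three sorted reference indices from the remaining frames.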

    def __getitem__(self, index):
        video_name = self.video_names[index]
        # create masks
        all_masks = create_random_shape_with_random_motion(
            self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)

        # create sample index
        selected_index = self._sample_index(self.video_dict[video_name],
                                            self.num_local_frames,
                                            self.num_ref_frames)

        # read video frames
        frames = []
        masks = []
        flows_f, flows_b = [], []
        for idx in selected_index:
            frame_list = self.frame_dict[video_name]
            img_path = os.path.join(self.video_root, video_name, frame_list[idx])
            img_bytes = self.file_client.get(img_path, 'img')
            img = imfrombytes(img_bytes, float32=False)  # decoded as BGR uint8
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            img = Image.fromarray(img)

            frames.append(img)
            masks.append(all_masks[idx])

            # Precomputed flows exist only between consecutive local frames,
            # so only the first num_local_frames - 1 iterations load them.
            if len(frames) <= self.num_local_frames - 1 and self.load_flow:
                current_n = frame_list[idx][:-4]
                next_n = frame_list[idx + 1][:-4]
                flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
                flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
                flow_f = flowread(flow_f_path, quantize=False)
                flow_b = flowread(flow_b_path, quantize=False)
                flow_f = resize_flow(flow_f, self.h, self.w)
                flow_b = resize_flow(flow_b, self.h, self.w)
                flows_f.append(flow_f)
                flows_b.append(flow_b)

            if len(frames) == self.num_local_frames:  # random temporal reverse
                if random.random() < 0.5:
                    frames.reverse()
                    masks.reverse()
                    if self.load_flow:
                        flows_f.reverse()
                        flows_b.reverse()
                        # reversing time turns forward flow into backward flow
                        flows_ = flows_f
                        flows_f = flows_b
                        flows_b = flows_

        if self.load_flow:
            frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
        else:
            frames = GroupRandomHorizontalFlip()(frames)

        # normalize, to tensors
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        if self.load_flow:
            flows_f = np.stack(flows_f, axis=-1)  # H W 2 T-1
            flows_b = np.stack(flows_b, axis=-1)
            flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
            flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()

        # img [-1,1] mask [0,1]
        if self.load_flow:
            return frame_tensors, mask_tensors, flows_f, flows_b, video_name
        else:
            return frame_tensors, mask_tensors, 'None', 'None', video_name


class TestDataset(torch.utils.data.Dataset):
    def __init__(self, args):
        self.args = args
        self.size = self.w, self.h = args['size']
        self.video_root = args['video_root']
        self.mask_root = args['mask_root']
        self.flow_root = args['flow_root']
        self.load_flow = args['load_flow']
        if self.load_flow:
            assert os.path.exists(self.flow_root)
        self.video_names = sorted(os.listdir(self.mask_root))

        self.video_dict = {}
        self.frame_dict = {}
        for v in self.video_names:
            frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
            v_len = len(frame_list)
            self.video_dict[v] = v_len
            self.frame_dict[v] = frame_list

        self._to_tensors = transforms.Compose([
            Stack(),
            ToTorchFormatTensor(),
        ])
        self.file_client = FileClient('disk')

    def __len__(self):
        return len(self.video_names)

    def __getitem__(self, index):
        video_name = self.video_names[index]
        selected_index = list(range(self.video_dict[video_name]))

        # read video frames
        frames = []
        masks = []
        flows_f, flows_b = [], []
        for idx in selected_index:
            frame_list = self.frame_dict[video_name]
            frame_path = os.path.join(self.video_root, video_name, frame_list[idx])
            img_bytes = self.file_client.get(frame_path, 'input')
            img = imfrombytes(img_bytes, float32=False)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            img = Image.fromarray(img)
            frames.append(img)

            mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
            mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L')

            # origin: 0 indicates missing. now: 1 indicates missing
            mask = np.asarray(mask)
            m = np.array(mask > 0).astype(np.uint8)

            # slightly enlarge the mask to cover boundary artifacts
            m = cv2.dilate(m,
                           cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
                           iterations=4)
            mask = Image.fromarray(m * 255)
            masks.append(mask)

            if len(frames) <= len(selected_index) - 1 and self.load_flow:
                current_n = frame_list[idx][:-4]
                next_n = frame_list[idx + 1][:-4]
                flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
                flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
                flow_f = flowread(flow_f_path, quantize=False)
                flow_b = flowread(flow_b_path, quantize=False)
                flow_f = resize_flow(flow_f, self.h, self.w)
                flow_b = resize_flow(flow_b, self.h, self.w)
                flows_f.append(flow_f)
                flows_b.append(flow_b)

        # normalize, to tensors
        frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
        frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
        mask_tensors = self._to_tensors(masks)
        if self.load_flow:
            flows_f = np.stack(flows_f, axis=-1)  # H W 2 T-1
            flows_b = np.stack(flows_b, axis=-1)
            flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
            flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()

        if self.load_flow:
            return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL
        else:
            return frame_tensors, mask_tensors, 'None', 'None', video_name
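

# Minimal usage sketch (not part of the original module). The dataset name,
# paths, and resolution below are placeholders; `load_flow` is disabled so no
# precomputed .flo files are required for a quick smoke test.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_args = {
        'name': 'my_dataset',                      # hypothetical dataset name
        'video_root': './datasets/my_dataset/JPEGImages',
        'flow_root': './datasets/my_dataset/Flows',
        'w': 432, 'h': 240,                        # assumed training resolution
        'num_local_frames': 10,
        'num_ref_frames': 1,
        'load_flow': False,
    }
    loader = DataLoader(TrainDataset(train_args), batch_size=2, shuffle=True)
    frames, masks, _, _, names = next(iter(loader))
    print(frames.shape, masks.shape, names)  # frames in [-1, 1], masks in [0, 1]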