# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import cv2
import imageio
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ColorJitter, GaussianBlur

from cotracker.datasets.utils import CoTrackerData


class CoTrackerDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(CoTrackerDataset, self).__init__()
        np.random.seed(0)
        torch.manual_seed(0)
        self.data_root = data_root
        self.seq_len = seq_len
        self.traj_per_sample = traj_per_sample
        self.sample_vis_1st_frame = sample_vis_1st_frame
        self.use_augs = use_augs
        self.crop_size = crop_size

        # photometric augmentation
        self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
        self.blur_aug_prob = 0.25
        self.color_aug_prob = 0.25

        # eraser (occlusion) augmentation
        self.eraser_aug_prob = 0.5
        self.eraser_bounds = [2, 100]
        self.eraser_max = 10

        # replace (occlusion) augmentation
        self.replace_aug_prob = 0.5
        self.replace_bounds = [2, 100]
        self.replace_max = 10

        # spatial augmentations
        self.pad_bounds = [0, 100]
        self.resize_lim = [0.25, 2.0]  # sample resizes from here
        self.resize_delta = 0.2
        self.max_crop_offset = 50
        self.do_flip = True
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.5

    def getitem_helper(self, index):
        # Subclasses must implement this and return (sample, gotit).
        raise NotImplementedError

    def __getitem__(self, index):
        sample, gotit = self.getitem_helper(index)
        if not gotit:
            print("warning: sampling failed")
            # fake sample, so we can still collate
            sample = CoTrackerData(
                video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
                trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
                visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
                valid=torch.zeros((self.seq_len, self.traj_per_sample)),
            )
        return sample, gotit
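
    # Because __getitem__ returns (sample, gotit) rather than a bare sample,
    # a DataLoader over this dataset needs a collate_fn that unpacks the
    # tuple. A minimal sketch, assuming the consumer only wants successful
    # samples (not part of the original file; the real project may collate
    # differently):
    #
    #   def collate_fn(batch):
    #       samples = [s for s, g in batch if g]
    #       return samples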

    def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
        T, N, _ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        if eraser:
            ############ eraser transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        return rgbs, trajs, visibles
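
    # Shape contract for add_photometric_augs above, for reference: rgbs is
    # a length-S sequence (list or stacked array) of HxWx3 uint8 frames,
    # trajs is (S, N, 2) in (x, y) pixel coordinates, and visibles is a
    # boolean (S, N) array that the method updates in place for any point
    # covered by an erased or replaced patch.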

    def add_spatial_augs(self, rgbs, trajs, visibles):
        T, N, __ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        rgbs = [rgb.astype(np.float32) for rgb in rgbs]

        ############ spatial transform ############

        # padding
        pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])

        rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
        trajs[:, :, 0] += pad_x0
        trajs[:, :, 1] += pad_y0
        H, W = rgbs[0].shape[:2]

        # scaling + stretching
        scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
        scale_x = scale
        scale_y = scale
        H_new = H
        W_new = W

        scale_delta_x = 0.0
        scale_delta_y = 0.0

        rgbs_scaled = []
        for s in range(S):
            if s == 1:
                scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
                scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
            elif s > 1:
                scale_delta_x = (
                    scale_delta_x * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
                scale_delta_y = (
                    scale_delta_y * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
            scale_x = scale_x + scale_delta_x
            scale_y = scale_y + scale_delta_y

            # bring h/w closer
            scale_xy = (scale_x + scale_y) * 0.5
            scale_x = scale_x * 0.5 + scale_xy * 0.5
            scale_y = scale_y * 0.5 + scale_xy * 0.5

            # don't get too crazy
            scale_x = np.clip(scale_x, 0.2, 2.0)
            scale_y = np.clip(scale_y, 0.2, 2.0)

            H_new = int(H * scale_y)
            W_new = int(W * scale_x)

            # make it at least slightly bigger than the crop area,
            # so that the random cropping can add diversity
            H_new = np.clip(H_new, self.crop_size[0] + 10, None)
            W_new = np.clip(W_new, self.crop_size[1] + 10, None)
            # recompute scale in case we clipped
            scale_x = (W_new - 1) / float(W - 1)
            scale_y = (H_new - 1) / float(H - 1)

            rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
            trajs[s, :, 0] *= scale_x
            trajs[s, :, 1] *= scale_y
        rgbs = rgbs_scaled

        ok_inds = visibles[0, :] > 0
        vis_trajs = trajs[:, ok_inds]  # S,?,2

        if vis_trajs.shape[1] > 0:
            mid_x = np.mean(vis_trajs[0, :, 0])
            mid_y = np.mean(vis_trajs[0, :, 1])
        else:
            mid_y = self.crop_size[0]
            mid_x = self.crop_size[1]

        x0 = int(mid_x - self.crop_size[1] // 2)
        y0 = int(mid_y - self.crop_size[0] // 2)

        offset_x = 0
        offset_y = 0

        for s in range(S):
            # on each frame, shift a bit more
            if s == 1:
                offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
                offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
            elif s > 1:
                offset_x = int(
                    offset_x * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
                offset_y = int(
                    offset_y * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
            x0 = x0 + offset_x
            y0 = y0 + offset_y

            H_new, W_new = rgbs[s].shape[:2]
            if H_new == self.crop_size[0]:
                y0 = 0
            else:
                y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)

            if W_new == self.crop_size[1]:
                x0 = 0
            else:
                x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)

            rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            trajs[s, :, 0] -= x0
            trajs[s, :, 1] -= y0

        H_new = self.crop_size[0]
        W_new = self.crop_size[1]

        # flip
        h_flipped = False
        v_flipped = False
        if self.do_flip:
            # h flip
            if np.random.rand() < self.h_flip_prob:
                h_flipped = True
                rgbs = [rgb[:, ::-1] for rgb in rgbs]
            # v flip
            if np.random.rand() < self.v_flip_prob:
                v_flipped = True
                rgbs = [rgb[::-1] for rgb in rgbs]
        if h_flipped:
            trajs[:, :, 0] = W_new - trajs[:, :, 0]
        if v_flipped:
            trajs[:, :, 1] = H_new - trajs[:, :, 1]

        return rgbs, trajs
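
    # Note: add_spatial_augs above keeps the trajectories consistent with
    # every spatial change: padding shifts them, per-frame rescaling
    # multiplies them, cropping subtracts the per-frame crop origin, and
    # flips mirror them about the crop edges. Points pushed outside the
    # crop are masked out later in getitem_helper rather than here.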

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape
        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        ############ spatial transform ############

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
        x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs
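
    # Note: after this simple crop, some coordinates in trajs may fall
    # outside [0, crop_size); getitem_helper below marks those points
    # invisible instead of clamping them.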


class KubricMovifDataset(CoTrackerDataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(KubricMovifDataset, self).__init__(
            data_root=data_root,
            crop_size=crop_size,
            seq_len=seq_len,
            traj_per_sample=traj_per_sample,
            sample_vis_1st_frame=sample_vis_1st_frame,
            use_augs=use_augs,
        )

        self.pad_bounds = [0, 25]
        self.resize_lim = [0.75, 1.25]  # sample resizes from here
        self.resize_delta = 0.05
        self.max_crop_offset = 15
        self.seq_names = [
            fname
            for fname in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, fname))
        ]
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))

    def getitem_helper(self, index):
        gotit = True
        seq_name = self.seq_names[index]

        npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
        rgb_path = os.path.join(self.data_root, seq_name, "frames")

        img_paths = sorted(os.listdir(rgb_path))
        rgbs = []
        for i, img_path in enumerate(img_paths):
            rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))

        rgbs = np.stack(rgbs)
        annot_dict = np.load(npy_path, allow_pickle=True).item()
        traj_2d = annot_dict["coords"]
        visibility = annot_dict["visibility"]

        # random crop
        assert self.seq_len <= len(rgbs)
        if self.seq_len < len(rgbs):
            start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]

            rgbs = rgbs[start_ind : start_ind + self.seq_len]
            traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
            visibility = visibility[:, start_ind : start_ind + self.seq_len]

        traj_2d = np.transpose(traj_2d, (1, 0, 2))
        # the stored flags are negated here, i.e. treated as occlusion flags
        visibility = np.transpose(np.logical_not(visibility), (1, 0))
        if self.use_augs:
            rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
            rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
        else:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)

        # mask out points that left the crop window
        visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        visibility = torch.from_numpy(visibility)
        traj_2d = torch.from_numpy(traj_2d)

        visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]

        if self.sample_vis_1st_frame:
            visibile_pts_inds = visibile_pts_first_frame_inds
        else:
            visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
                :, 0
            ]
            visibile_pts_inds = torch.cat(
                (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
            )
        point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
        if len(point_inds) < self.traj_per_sample:
            gotit = False

        visible_inds_sampled = visibile_pts_inds[point_inds]

        trajs = traj_2d[:, visible_inds_sampled].float()
        visibles = visibility[:, visible_inds_sampled]
        valids = torch.ones((self.seq_len, self.traj_per_sample))

        rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
        sample = CoTrackerData(
            video=rgbs,
            trajectory=trajs,
            visibility=visibles,
            valid=valids,
            seq_name=seq_name,
        )
        return sample, gotit

    def __len__(self):
        return len(self.seq_names)
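

# Usage sketch, assuming a local Kubric MOVi-F dump at ./kubric laid out as
# described above (hypothetical path; not part of the original file):
#
#   dataset = KubricMovifDataset(
#       data_root="./kubric",
#       crop_size=(384, 512),
#       seq_len=24,
#       traj_per_sample=768,
#       use_augs=True,
#   )
#   sample, gotit = dataset[0]
#   # sample.video:      (seq_len, 3, 384, 512) float tensor
#   # sample.trajectory: (seq_len, traj_per_sample, 2)
#   # sample.visibility: (seq_len, traj_per_sample)
#   # sample.valid:      (seq_len, traj_per_sample)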