File size: 5,202 Bytes
c705408 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import gzip
import torch
import numpy as np
import as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple
from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
@dataclass
class ImageAnnotation:
    """Path and pixel dimensions of one image file referenced by a frame annotation.

    Declared as a dataclass so that instances can be built from keyword/positional
    field values (``ImageAnnotation(path=..., size=...)``).
    """

    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]
@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    # image file (path and H x W size) belonging to this frame
    image: ImageAnnotation
    # optional free-form metadata; its schema is not defined in this file
    meta: Optional[Dict[str, Any]] = None

    # name of the camera that produced the frame; the dataset keeps the
    # frames whose camera_name is "left"
    camera_name: Optional[str] = None
    # NOTE(review): annotated as Optional[str], but the dataset reads
    # `trajectories["path"]`, which suggests a mapping with a "path" key.
    # Confirm against the annotation files before changing the annotation.
    trajectories: Optional[str] = None
class DynamicReplicaDataset(data.Dataset):
# Build the dataset index: store the configuration on the instance, read the
# per-frame annotations for the requested split, and cut each sequence into
# windows of frames.
#
# NOTE(review): this copy of the method is incomplete. The parameter list and
# the closing `):` of the signature are missing, as are the opener of the
# `with` statement and the bodies of some `if`/`for` statements. The body
# reads `root`, `sample_len`, `split`, `traj_per_sample`, `rgbd_input`,
# `crop_size` and `only_first_n_samples`, so those are presumably the
# parameters; their order and defaults cannot be determined from here.
def __init__(
super(DynamicReplicaDataset, self).__init__()
# Keep the configuration for later use by the other methods.
self.root = root
self.sample_len = sample_len
self.split = split
self.traj_per_sample = traj_per_sample
self.rgbd_input = rgbd_input
self.crop_size = crop_size
# The annotation file is named after the split and lives in <root>/<split>/.
frame_annotations_file = f"frame_annotations_{split}.jgz"
# One entry per sample; `__len__` and `__getitem__` read this list.
self.sample_list = []
# NOTE(review): the call that opens the file is missing on the line above;
# presumably `with gzip.open(` given the `gzip` import, the text mode "rt"
# and the `) as zipfile:` that follows. TODO confirm.
os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
) as zipfile:
# Parse the file contents into a list of DynamicReplicaFrameAnnotation.
frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
# Frame annotations grouped in lists; the keys are used as sequence names below.
seq_annot = defaultdict(list)
for frame_annot in frame_annots_list:
# Only frames from the "left" camera are considered.
# NOTE(review): the body of this `if` is missing; presumably it appends
# `frame_annot` to `seq_annot` under its sequence name. TODO confirm.
if frame_annot.camera_name == "left":
for seq_name in seq_annot.keys():
seq_len = len(seq_annot[seq_name])
# A positive sample_len sets the window length; otherwise the whole
# sequence forms a single window.
step = self.sample_len if self.sample_len > 0 else seq_len
# Counts the windows taken from the current sequence.
counter = 0
# Consecutive, non-overlapping windows; the last one may be shorter.
for ref_idx in range(0, seq_len, step):
sample = seq_annot[seq_name][ref_idx : ref_idx + step]
# NOTE(review): `sample` is never stored in the visible code; presumably it
# is appended to `self.sample_list` on a missing line. TODO confirm.
counter += 1
# NOTE(review): the body of this `if` is missing; presumably it stops taking
# windows from this sequence once `only_first_n_samples` is reached.
if only_first_n_samples > 0 and counter >= only_first_n_samples:
def __len__(self):
return len(self.sample_list)
def crop(self, rgbs, trajs):
    """Center-crop every frame to ``self.crop_size`` and shift the tracks to match.

    The crop window is deterministic: it is centred in the frame along each
    axis. When the requested crop is at least as large as the frame along an
    axis, the offset for that axis is 0 and the frame keeps its full extent
    there (slicing past the end is clamped).

    Args:
        rgbs: sequence of frames, each indexed as ``[row, col, ...]``; all
            frames are cropped with the offsets computed from ``rgbs[0]``.
        trajs: 3-D array indexed as ``[frame, point, coord]`` where coord 0
            is x (columns) and coord 1 is y (rows).

    Returns:
        ``(cropped_rgbs, shifted_trajs)``: a new list of cropped frames and a
        copy of ``trajs`` expressed in the cropped frame's coordinates. The
        caller's ``trajs`` is left untouched.

    Raises:
        AssertionError: if the number of frames differs from ``trajs.shape[0]``.
    """
    num_traj_frames, _, _ = trajs.shape
    assert len(rgbs) == num_traj_frames

    height, width = rgbs[0].shape[:2]
    crop_h, crop_w = self.crop_size[0], self.crop_size[1]

    # Centre the window; fall back to offset 0 when the crop does not fit.
    y0 = 0 if crop_h >= height else (height - crop_h) // 2
    x0 = 0 if crop_w >= width else (width - crop_w) // 2

    cropped = [frame[y0 : y0 + crop_h, x0 : x0 + crop_w] for frame in rgbs]

    # Shift a copy so the caller's array is not modified as a side effect.
    shifted = trajs.clone() if torch.is_tensor(trajs) else trajs.copy()
    shifted[:, :, 0] -= x0
    shifted[:, :, 1] -= y0

    return cropped, shifted
def __getitem__(self, index):
sample = self.sample_list[index]
T = len(sample)
rgbs, visibilities, traj_2d = [], [], []
H, W = sample[0].image.size
image_size = (H, W)
for i in range(T):
traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
traj = torch.load(traj_path)
traj_2d.append(traj["traj_2d"].numpy()[..., :2])
traj_2d = np.stack(traj_2d)
visibility = np.stack(visibilities)
T, N, D = traj_2d.shape
# subsample trajectories for augmentations
visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
traj_2d = traj_2d[:, visible_inds_sampled]
visibility = visibility[:, visible_inds_sampled]
if self.crop_size is not None:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
H, W, _ = rgbs[0].shape
image_size = self.crop_size
visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
# filter out points that're visible for less than 10 frames
visible_inds_resampled = visibility.sum(0) > 10
traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
rgbs = np.stack(rgbs, 0)
video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
return CoTrackerData(
valid=torch.ones(T, N),