# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
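"""Dataset that wraps raw VOS (video object segmentation) videos and produces
transformed VideoDatapoint samples for training and evaluation."""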

import logging
import random
from copy import deepcopy

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
from torchvision.datasets.vision import VisionDataset

from training.dataset.vos_raw_dataset import VOSRawDataset
from training.dataset.vos_sampler import VOSSampler
from training.dataset.vos_segment_loader import JSONSegmentLoader
from training.utils.data_utils import Frame, Object, VideoDatapoint

MAX_RETRIES = 100


class VOSDataset(VisionDataset):
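    """A VisionDataset over VOS videos.

    Each sample is built by (1) fetching a video and its segment loader from
    `video_dataset`, (2) sampling frames and object ids with `sampler`, and
    (3) packing the frames, masks, and object ids into a `VideoDatapoint`
    that is then passed through `transforms`.
    """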

    def __init__(
        self,
        transforms,
        training: bool,
        video_dataset: VOSRawDataset,
        sampler: VOSSampler,
        multiplier: int,
        always_target=True,
        target_segments_available=True,
    ):
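        """
        Args:
            transforms: list of callables applied to each VideoDatapoint as
                transform(datapoint, epoch=...).
            training: if True, failed video loads are retried with a random
                index instead of raising.
            video_dataset: raw video source providing (video, segment_loader).
            sampler: strategy for picking frames and object ids per video.
            multiplier: per-video repeat factor, used to oversample this dataset.
            always_target: if True, objects with no ground-truth mask get a
                zero mask; if False, such objects are dropped.
            target_segments_available: whether ground-truth segments exist.
        """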
        self._transforms = transforms
        self.training = training
        self.video_dataset = video_dataset
        self.sampler = sampler

        self.repeat_factors = torch.ones(len(self.video_dataset), dtype=torch.float32)
        self.repeat_factors *= multiplier
        print(f"Raw dataset length = {len(self.video_dataset)}")

        self.curr_epoch = 0  # Used in case data loader behavior changes across epochs
        self.always_target = always_target
        self.target_segments_available = target_segments_available

    def _get_datapoint(self, idx):
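        """Load, sample, and transform the video at `idx`, retrying on failure.

        During training, a failed load falls back to a random video for up to
        MAX_RETRIES attempts; during evaluation, the failure is re-raised.
        """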
        for retry in range(MAX_RETRIES):
            try:
                if isinstance(idx, torch.Tensor):
                    idx = idx.item()
                # sample a video
                video, segment_loader = self.video_dataset.get_video(idx)
                # sample frames and object indices to be used in a datapoint
                sampled_frms_and_objs = self.sampler.sample(
                    video, segment_loader, epoch=self.curr_epoch
                )
                break  # Successfully loaded video
            except Exception as e:
                if self.training:
                    logging.warning(
                        f"Loading failed (id={idx}); Retry {retry} with exception: {e}"
                    )
                    idx = random.randrange(0, len(self.video_dataset))
                else:
                    # Shouldn't fail to load a val video
                    raise e

        datapoint = self.construct(video, sampled_frms_and_objs, segment_loader)
        for transform in self._transforms:
            datapoint = transform(datapoint, epoch=self.curr_epoch)
        return datapoint

    def construct(self, video, sampled_frms_and_objs, segment_loader):
        """
        Constructs a VideoDatapoint sample to pass to transforms
        """
        sampled_frames = sampled_frms_and_objs.frames
        sampled_object_ids = sampled_frms_and_objs.object_ids

        images = []
        rgb_images = load_images(sampled_frames)
        # Iterate over the sampled frames and store their rgb data and object data (bbox, segment)
        for frame_idx, frame in enumerate(sampled_frames):
            w, h = rgb_images[frame_idx].size
            images.append(
                Frame(
                    data=rgb_images[frame_idx],
                    objects=[],
                )
            )
            # We load the gt segments associated with the current frame
            if isinstance(segment_loader, JSONSegmentLoader):
                segments = segment_loader.load(
                    frame.frame_idx, obj_ids=sampled_object_ids
                )
            else:
                segments = segment_loader.load(frame.frame_idx)
            for obj_id in sampled_object_ids:
                # Extract the segment
                if obj_id in segments:
                    assert (
                        segments[obj_id] is not None
                    ), "None targets are not supported"
                    # segment is uint8 and remains uint8 throughout the transforms
                    segment = segments[obj_id].to(torch.uint8)
                else:
                    # There is no target, we either use a zero mask target or drop this object
                    if not self.always_target:
                        continue
                    segment = torch.zeros(h, w, dtype=torch.uint8)
                images[frame_idx].objects.append(
                    Object(
                        object_id=obj_id,
                        frame_index=frame.frame_idx,
                        segment=segment,
                    )
                )
        return VideoDatapoint(
            frames=images,
            video_id=video.video_id,
            size=(h, w),
        )

    def __getitem__(self, idx):
        return self._get_datapoint(idx)

    def __len__(self):
        return len(self.video_dataset)


def load_images(frames):
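    """Return an RGB PIL image per frame, reading from disk when a frame has
    no in-memory data; repeated paths are served from a cache and deep-copied
    so each occurrence can be transformed independently."""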
    all_images = []
    cache = {}
    for frame in frames:
        if frame.data is None:
            # Load the frame rgb data from file
            path = frame.image_path
            if path in cache:
                all_images.append(deepcopy(all_images[cache[path]]))
                continue
            with g_pathmgr.open(path, "rb") as fopen:
                all_images.append(PILImage.open(fopen).convert("RGB"))
            cache[path] = len(all_images) - 1
        else:
            # The frame rgb data has already been loaded
            # Convert it to a PILImage
            all_images.append(tensor_2_PIL(frame.data))
    return all_images


def tensor_2_PIL(data: torch.Tensor) -> PILImage.Image:
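    """Convert a CHW float tensor with values in [0, 1] to an RGB PIL image."""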
    data = data.cpu().numpy().transpose((1, 2, 0)) * 255.0
    data = data.astype(np.uint8)
    return PILImage.fromarray(data)
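

# Example usage (a minimal sketch, not part of the module): build a dataset
# and fetch one sample. `MyRawDataset` and `MySampler` are hypothetical
# stand-ins for concrete VOSRawDataset / VOSSampler implementations; their
# constructor arguments are illustrative only.
#
#   dataset = VOSDataset(
#       transforms=[],  # each transform is called as transform(datapoint, epoch=...)
#       training=True,
#       video_dataset=MyRawDataset(...),
#       sampler=MySampler(...),
#       multiplier=1,
#   )
#   datapoint = dataset[0]  # a VideoDatapoint with .frames, .video_id, .size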