# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import glob | |
import json | |
import os | |
import numpy as np | |
import pandas as pd | |
import torch | |
from PIL import Image as PILImage | |
# pycocotools is optional; only the RLE-based loaders below need mask_utils.
try:
    from pycocotools import mask as mask_utils
except ImportError:
    # Narrowed from a bare `except:` so that only a genuinely missing
    # package is tolerated; other errors are not silently swallowed.
    pass
class JSONSegmentLoader:
    """Loads per-frame RLE mask annotations from a single JSON file.

    The JSON is either a plain list of per-frame annotations, or a dict with
    a "masklet" (or "masks") field and optionally an "fps" field used to
    derive the annotation stride relative to the video frame rate.
    """

    def __init__(self, video_json_path, ann_every=1, frames_fps=24, valid_obj_ids=None):
        # Annotations in the JSON are provided only every `ann_every`-th frame.
        self.ann_every = ann_every
        # When given, restrict sampling to these object ids.
        self.valid_obj_ids = valid_obj_ids
        with open(video_json_path, "r") as f:
            data = json.load(f)
        if isinstance(data, list):
            self.frame_annots = data
        elif isinstance(data, dict):
            field = "masklet" if "masklet" in data else "masks"
            self.frame_annots = data[field]
            if "fps" in data:
                fps_value = data["fps"]
                annotations_fps = int(
                    fps_value[0] if isinstance(fps_value, list) else fps_value
                )
                # The video fps must be an integer multiple of the annotation fps.
                assert frames_fps % annotations_fps == 0
                self.ann_every = frames_fps // annotations_fps
        else:
            raise NotImplementedError

    def load(self, frame_id, obj_ids=None):
        """Return {obj_id: boolean mask tensor or None} for an annotated frame."""
        assert frame_id % self.ann_every == 0
        rle_mask = self.frame_annots[frame_id // self.ann_every]

        keep_ids = set(range(len(rle_mask)))
        if self.valid_obj_ids is not None:
            # Drop the masklets that have been filtered out for this video.
            keep_ids &= set(self.valid_obj_ids)
        if obj_ids is not None:
            # Keep only the objects that have been sampled.
            keep_ids &= set(obj_ids)
        keep_ids = sorted(keep_ids)

        # Gather the non-empty RLEs, remembering which decode slot each id maps to.
        id_2_idx = {}
        rle_mask_filtered = []
        for obj_id in keep_ids:
            rle = rle_mask[obj_id]
            if rle is None:
                id_2_idx[obj_id] = None
            else:
                id_2_idx[obj_id] = len(rle_mask_filtered)
                rle_mask_filtered.append(rle)

        # Decode all kept RLEs in one call; result is (num_obj, h, w).
        raw_segments = torch.from_numpy(mask_utils.decode(rle_mask_filtered)).permute(
            2, 0, 1
        )
        return {
            obj_id: (
                None if id_2_idx[obj_id] is None else raw_segments[id_2_idx[obj_id]]
            )
            for obj_id in keep_ids
        }

    def get_valid_obj_frames_ids(self, num_frames_min=None):
        """Map each object id to the frame ids where its mask is not None."""
        num_objects = len(self.frame_annots[0])
        res = {obj_id: [] for obj_id in range(num_objects)}
        for annot_idx, annot in enumerate(self.frame_annots):
            frame_id = int(annot_idx * self.ann_every)
            for obj_id in range(num_objects):
                if annot[obj_id] is not None:
                    res[obj_id].append(frame_id)
        if num_frames_min is not None:
            # Discard masklets with fewer than num_frames_min valid masks.
            res = {
                obj_id: frames
                for obj_id, frames in res.items()
                if len(frames) >= num_frames_min
            }
        return res
class PalettisedPNGSegmentLoader:
    def __init__(self, video_png_root):
        """
        SegmentLoader for datasets with masks stored as palettised PNGs.

        Args:
            video_png_root: the folder that contains all the masks stored as
                PNGs, one file per frame, named by frame id.
        """
        self.video_png_root = video_png_root
        # Build a mapping from frame id to its PNG mask filename. Note that
        # in some datasets the PNG names can have more than 5 digits, e.g.
        # "00000000.png" instead of "00000.png", so we index by int(frame_id)
        # rather than assuming a fixed-width name.
        png_filenames = os.listdir(self.video_png_root)
        self.frame_id_to_png_filename = {}
        for filename in png_filenames:
            frame_id, _ = os.path.splitext(filename)
            self.frame_id_to_png_filename[int(frame_id)] = filename

    def load(self, frame_id):
        """
        Load the single palettised mask for `frame_id` from the disk.

        Args:
            frame_id: int, identifies which mask file to load
        Return:
            binary_segments: dict mapping object id -> boolean mask tensor
                (palette value 0 is treated as background and excluded)
        """
        # Resolve the mask path from the filename index built in __init__.
        mask_path = os.path.join(
            self.video_png_root, self.frame_id_to_png_filename[frame_id]
        )
        # Load the palettised mask; each distinct palette value is one object.
        masks = PILImage.open(mask_path).convert("P")
        masks = np.array(masks)
        object_id = pd.unique(masks.flatten())
        object_id = object_id[object_id != 0]  # remove background (0)
        # Convert into N binary segmentation masks, one per object id.
        binary_segments = {}
        for i in object_id:
            bs = masks == i
            binary_segments[i] = torch.from_numpy(bs)
        return binary_segments

    def __len__(self):
        # BUGFIX: previously this returned None, which made len(loader) raise
        # TypeError. The loader's length is the number of indexed mask frames.
        return len(self.frame_id_to_png_filename)
class MultiplePNGSegmentLoader:
    def __init__(self, video_png_root, single_object_mode=False):
        """
        SegmentLoader for masks stored as one PNG folder per object.

        Args:
            video_png_root: the folder that contains all the masks stored as
                PNGs. In single-object mode this is a single object's folder
                (PNGs directly inside, folder named by object id); otherwise
                it contains one sub-folder per object.
            single_object_mode: whether to load only a single object at a time
        """
        self.video_png_root = video_png_root
        self.single_object_mode = single_object_mode
        # Read one mask to learn the video resolution; it is needed to build
        # empty masks for frames whose PNG file is missing.
        if self.single_object_mode:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*.png"))[0]
        else:
            tmp_mask_path = glob.glob(os.path.join(video_png_root, "*", "*.png"))[0]
        tmp_mask = np.array(PILImage.open(tmp_mask_path))
        self.H = tmp_mask.shape[0]
        self.W = tmp_mask.shape[1]
        if self.single_object_mode:
            # BUGFIX: use os.path.basename instead of split("/")[-1] so the
            # object id parses correctly on all platforms (e.g. Windows paths);
            # identical result on POSIX.
            self.obj_id = (
                int(os.path.basename(video_png_root)) + 1
            )  # offset by 1 as bg is 0
        else:
            self.obj_id = None

    def load(self, frame_id):
        """Dispatch to the single- or multi-object loader for `frame_id`."""
        if self.single_object_mode:
            return self._load_single_png(frame_id)
        else:
            return self._load_multiple_pngs(frame_id)

    def _load_single_png(self, frame_id):
        """
        Load a single PNG mask from the disk (path: f'{self.obj_id}/{frame_id:05d}.png').

        Args:
            frame_id: int, defines the mask path
        Return:
            binary_segments: dict mapping self.obj_id -> boolean mask tensor
        """
        mask_path = os.path.join(self.video_png_root, f"{frame_id:05d}.png")
        binary_segments = {}
        if os.path.exists(mask_path):
            mask = np.array(PILImage.open(mask_path))
        else:
            # If the PNG doesn't exist, return an all-background (empty) mask.
            mask = np.zeros((self.H, self.W), dtype=bool)
        binary_segments[self.obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def _load_multiple_pngs(self, frame_id):
        """
        Load one PNG mask per object from the disk (path: f'{obj_id}/{frame_id:05d}.png').

        Args:
            frame_id: int, defines the mask path
        Return:
            binary_segments: dict mapping object id -> boolean mask tensor
        """
        # Each sub-folder of video_png_root corresponds to one object id.
        all_objects = sorted(glob.glob(os.path.join(self.video_png_root, "*")))
        num_objects = len(all_objects)
        assert num_objects > 0
        # Load the masks.
        binary_segments = {}
        for obj_folder in all_objects:
            # obj_folder is {video_name}/{obj_id}; the folder name is the id.
            # BUGFIX: os.path.basename instead of split("/") for portability.
            obj_id = int(os.path.basename(obj_folder))
            obj_id = obj_id + 1  # offset 1 as bg is 0
            mask_path = os.path.join(obj_folder, f"{frame_id:05d}.png")
            if os.path.exists(mask_path):
                mask = np.array(PILImage.open(mask_path))
            else:
                mask = np.zeros((self.H, self.W), dtype=bool)
            binary_segments[obj_id] = torch.from_numpy(mask > 0)
        return binary_segments

    def __len__(self):
        # NOTE(review): this returns None, so len() on this loader raises
        # TypeError; kept as-is because no caller-visible length is defined
        # for this loader — confirm intended semantics before changing.
        return
class LazySegments:
    """Dict-like container of RLE segments that decodes each mask lazily,
    only when it is first accessed, and memoizes the decoded result.
    """

    def __init__(self):
        self.segments = {}  # key -> raw RLE
        self.cache = {}  # key -> decoded mask tensor

    def __setitem__(self, key, item):
        self.segments[key] = item

    def __getitem__(self, key):
        # Serve from the cache when possible; otherwise decode and memoize.
        if key not in self.cache:
            rle = self.segments[key]
            decoded = torch.from_numpy(mask_utils.decode([rle]))
            self.cache[key] = decoded.permute(2, 0, 1)[0]
        return self.cache[key]

    def __contains__(self, key):
        return key in self.segments

    def __len__(self):
        return len(self.segments)

    def keys(self):
        return self.segments.keys()
class SA1BSegmentLoader:
    def __init__(
        self,
        video_mask_path,
        mask_area_frac_thresh=1.1,
        video_frame_path=None,
        uncertain_iou=-1,
    ):
        """
        Loader for SA-1B style per-image mask annotations.

        Args:
            video_mask_path: path to a JSON with an "annotations" list
            mask_area_frac_thresh: drop masks whose area fraction of the
                image is >= this value; the filter is active only when the
                threshold is <= 1.0 (the default 1.1 disables it)
            video_frame_path: image path, opened only when the area filter
                is active (needed to compute the image area)
            uncertain_iou: drop masks whose "uncertain_iou" (stability
                score) is below this value
        """
        with open(video_mask_path, "r") as f:
            self.frame_annots = json.load(f)

        filter_by_area = mask_area_frac_thresh <= 1.0
        if filter_by_area:
            # Lazily read the frame, just to compute the image area.
            orig_w, orig_h = PILImage.open(video_frame_path).size
            area = orig_w * orig_h

        self.frame_annots = self.frame_annots["annotations"]

        rle_masks = []
        for frame_annot in self.frame_annots:
            # Skip empty masks.
            if not frame_annot["area"] > 0:
                continue
            # Skip masks whose stability score is below the threshold.
            if ("uncertain_iou" in frame_annot) and (
                frame_annot["uncertain_iou"] < uncertain_iou
            ):
                # uncertain_iou is stability score
                continue
            # Skip masks covering too large a fraction of the image.
            if filter_by_area and (frame_annot["area"] / area) >= mask_area_frac_thresh:
                continue
            rle_masks.append(frame_annot["segmentation"])

        # Keep the surviving RLEs in a lazily-decoded container.
        self.segments = LazySegments()
        for idx, rle in enumerate(rle_masks):
            self.segments[idx] = rle

    def load(self, frame_idx):
        # An SA-1B annotation covers a single image, so the same segments
        # are returned regardless of frame_idx.
        return self.segments