# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou (xueyan@cs.wisc.edu)
# --------------------------------------------------------
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
from typing import List, Optional, Tuple, Any
import torch
import torchvision
from torch import nn, Tensor, device
import torch.distributed as dist
import torch.nn.functional as F
from detectron2.layers import cat, shapes_to_tensor
from utilities.constants import *

def pad_arbitrary_tensors(tensors, padding_value=0.):
    """Pad a list of 2-D tensors to their element-wise max shape, filling with `padding_value`."""
    max_len = torch.stack([torch.tensor(x.shape) for x in tensors]).max(dim=0)[0]
    padded_tensor = torch.empty([len(tensors)] + max_len.tolist(), device=tensors[0].device).fill_(padding_value)
    for i, x in enumerate(tensors):
        padded_tensor[i, :x.shape[0], :x.shape[1]] = x
    return padded_tensor
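
def _demo_pad_arbitrary_tensors():
    # Illustrative usage sketch, not part of the original module: pad two
    # differently sized 2-D tensors to a common (3, 3) shape with a -1 fill.
    a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    b = torch.ones(3, 2)
    padded = pad_arbitrary_tensors([a, b], padding_value=-1.)
    assert padded.shape == (2, 3, 3)
    assert (padded[1, :, 2] == -1.).all()  # column 2 of `b` is pure padding
    return padded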

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes
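
def _demo_max_by_axis():
    # Illustrative usage sketch, not part of the original module: element-wise
    # maximum over shape lists. Note that `_max_by_axis` aliases its first
    # sublist, so that sublist is mutated in place.
    shapes = [[3, 100, 80], [3, 90, 120]]
    assert _max_by_axis(shapes) == [3, 100, 120]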

class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            # the redundant assert helps TorchScript narrow the Optional type
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    elif tensor_list[0].ndim == 2:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(txt.shape) for txt in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, l = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, l), dtype=torch.bool, device=device)
        for txt, pad_txt, m in zip(tensor_list, tensor, mask):
            pad_txt[: txt.shape[0], : txt.shape[1]] = txt
            m[: txt.shape[1]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)
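
def _demo_nested_tensor_from_tensor_list():
    # Illustrative usage sketch, not part of the original module: two RGB
    # images of different sizes become one zero-padded batch tensor plus a
    # boolean mask that is True exactly on padded pixels.
    imgs = [torch.rand(3, 4, 6), torch.rand(3, 5, 3)]
    nested = nested_tensor_from_tensor_list(imgs)
    tensor, mask = nested.decompose()
    assert tensor.shape == (2, 3, 5, 6)
    assert mask.shape == (2, 5, 6)
    assert not mask[0, :4, :6].any()  # valid region of the first image
    return nested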

def _collate_and_pad_divisibility(tensor_list: list, div=32):
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)
    c, h, w = max_size
    # round the shared H and W up to the nearest multiple of `div`
    pad_h = (div - h % div) if h % div != 0 else 0
    pad_w = (div - w % div) if w % div != 0 else 0
    max_size = (c, h + pad_h, w + pad_w)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))
    # note: only the padded images are returned; `padded_masks` is built but
    # unused here (a leftover from _onnx_nested_tensor_from_tensor_list below)
    return padded_imgs
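
def _demo_collate_and_pad_divisibility():
    # Illustrative usage sketch, not part of the original module: both images
    # are padded (bottom/right) so the shared H and W become multiples of 32.
    imgs = [torch.rand(3, 30, 40), torch.rand(3, 33, 37)]
    padded = _collate_and_pad_divisibility(imgs, div=32)
    assert all(p.shape == (3, 64, 64) for p in padded)
    return padded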

# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))
    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)
    return NestedTensor(tensor, mask=mask)

def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

# TODO: add background to
def get_class_names(name):
    if name is None:
        return None
    elif 'refcoco' in name:
        return ["background"]
    elif 'biomed' in name:
        return BIOMED_CLASSES + ["background"]
    elif 'med_sam' in name:
        # MedSAM class names
        medsam_classes = ['liver', 'lung', 'pancreas', 'stomach', 'heart', 'gallbladder', 'prostate',
                          'brain ventricles', 'cerebellum', 'left heart ventricle', 'right heart ventricle',
                          'vessel', 'polyp', 'surgical tool', 'pleural effusion', 'infection', 'gland', 'tumor']
        return medsam_classes + ["background"]
    elif 'coco' in name:
        return COCO_PANOPTIC_CLASSES + ["background"]
    elif 'ade20k_full' in name:
        return ADE20K_847 + ["background"]
    elif 'ade' in name:
        return ADE_PANOPTIC_CLASSES + ["background"]
    elif 'scannet_41' in name:
        return SCAN_40 + ["background"]
    elif 'scannet_21' in name:
        return SCAN_20 + ["background"]
    elif 'sun' in name:
        return SUN_RGBD_37 + ["background"]
    elif 'voc' in name:
        return PASCAL_CLASSES + ["background"]
    elif name == 'cityscapes_fine_sem_seg_val':
        return CITYSCAPES + ["background"]
    elif name == 'cityscapes_fine_instance_seg_val':
        return CITYSCAPES_THING + ["background"]
    elif name in ['cityscapes_fine_panoptic_val']:
        return CITYSCAPES + ["background"]
    elif name == 'bdd10k_val_sem_seg':
        return BDD_SEM + ["background"]
    elif name == 'bdd10k_40_panoptic_val':
        return BDD_PANO + ["background"]
    elif 'vlp' in name:
        return ["background"]
    else:
        assert False, "text dataset name {} is not defined".format(name)
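
def _demo_get_class_names():
    # Illustrative usage sketch, not part of the original module: any dataset
    # name containing 'med_sam' maps to the hard-coded MedSAM list plus the
    # trailing "background" class.
    names = get_class_names('med_sam')
    assert names[0] == 'liver' and names[-1] == "background"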

def get_iou(gt_masks, pred_masks, ignore_label=-1):
    # pixels labeled `ignore_label` in `gt_masks` are excluded from both the
    # intersection and the union; `pred_masks` is expected to be boolean
    rev_ignore_mask = ~(gt_masks == ignore_label)
    gt_masks = gt_masks.bool()
    n, h, w = gt_masks.shape
    intersection = ((gt_masks & pred_masks) & rev_ignore_mask).reshape(n, h * w).sum(dim=-1)
    union = ((gt_masks | pred_masks) & rev_ignore_mask).reshape(n, h * w).sum(dim=-1)
    ious = intersection / union
    return ious
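
def _demo_get_iou():
    # Illustrative usage sketch, not part of the original module: the pixel
    # labeled -1 in the ground truth is dropped from both intersection and
    # union, giving IoU = 1 / 2 here.
    gt = torch.tensor([[[1, 1], [0, -1]]])
    pred = torch.tensor([[[True, False], [False, True]]])
    assert torch.allclose(get_iou(gt, pred), torch.tensor([0.5]))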

class Spatial_ImageList(object):
    """
    Structure that holds a list of images (of possibly
    varying sizes) as a single tensor.
    This works by padding the images to the same size.
    The original sizes of each image are stored in `image_sizes`.
    Attributes:
        image_sizes (list[tuple[int, int, int]]): each tuple is (c, h, w).
            During tracing, it becomes list[Tensor] instead.
    """

    def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int, int]]):
        """
        Arguments:
            tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1
            image_sizes (list[tuple[int, int, int]]): Each tuple is (c, h, w). Its
                (h, w) can be smaller than (H, W) due to padding.
        """
        self.tensor = tensor
        self.image_sizes = image_sizes

    def __len__(self) -> int:
        return len(self.image_sizes)

    def __getitem__(self, idx) -> torch.Tensor:
        """
        Access the individual image in its original size.
        Args:
            idx: int or slice
        Returns:
            Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1
        """
        size = self.image_sizes[idx]
        # image_sizes holds (c, h, w) tuples here, so index the spatial
        # extents from the end rather than assuming (h, w) pairs
        return self.tensor[idx, ..., : size[-2], : size[-1]]

    @torch.jit.unused
    def to(self, *args: Any, **kwargs: Any) -> "Spatial_ImageList":
        cast_tensor = self.tensor.to(*args, **kwargs)
        return Spatial_ImageList(cast_tensor, self.image_sizes)

    @property
    def device(self) -> device:
        return self.tensor.device

    @staticmethod
    def from_tensors(
        tensors: List[torch.Tensor], size_divisibility: int = 0, pad_value: float = 0.0
    ) -> "Spatial_ImageList":
        """
        Args:
            tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or
                (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded
                to the same shape with `pad_value`.
            size_divisibility (int): If `size_divisibility > 0`, add padding to ensure
                the common height and width is divisible by `size_divisibility`.
                This depends on the model and many models need a divisibility of 32.
            pad_value (float): value to pad.
        Returns:
            a `Spatial_ImageList`.
        """
        assert len(tensors) > 0
        assert isinstance(tensors, (tuple, list))
        for t in tensors:
            assert isinstance(t, torch.Tensor), type(t)

        image_sizes = [(im.shape[-3], im.shape[-2], im.shape[-1]) for im in tensors]
        image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes]
        max_size = torch.stack(image_sizes_tensor).max(0).values

        if size_divisibility > 1:
            stride = size_divisibility
            # the last two dims are H,W, both subject to divisibility requirement
            max_size[-2:] = (max_size[-2:] + (stride - 1)).div(stride, rounding_mode="floor") * stride

        # handle weirdness of scripting and tracing ...
        if torch.jit.is_scripting():
            max_size: List[int] = max_size.to(dtype=torch.long).tolist()
        else:
            if torch.jit.is_tracing():
                image_sizes = image_sizes_tensor

        if len(tensors) == 1:
            # This seems slightly (2%) faster.
            # TODO: check whether it's faster for multiple images as well
            image_size = image_sizes[0]
            padding_size = [0, max_size[-1] - image_size[2], 0, max_size[-2] - image_size[1]]
            batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0)
        else:
            # max_size can be a tensor in tracing mode, therefore convert to list
            batch_shape = [len(tensors)] + list(tensors[0].shape[:-3]) + list(max_size)
            batched_imgs = tensors[0].new_full(batch_shape, pad_value)
            for img, pad_img in zip(tensors, batched_imgs):
                pad_img[: img.shape[-3], : img.shape[-2], : img.shape[-1]].copy_(img)

        return Spatial_ImageList(batched_imgs.contiguous(), image_sizes)
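
def _demo_spatial_image_list():
    # Illustrative usage sketch, not part of the original module: two (C, H, W)
    # tensors are batched with divisibility-32 padding; indexing recovers each
    # image at its original size.
    ims = [torch.rand(3, 30, 40), torch.rand(3, 33, 37)]
    image_list = Spatial_ImageList.from_tensors(ims, size_divisibility=32)
    assert image_list.tensor.shape == (2, 3, 64, 64)
    assert image_list.image_sizes == [(3, 30, 40), (3, 33, 37)]
    assert image_list[0].shape == (3, 30, 40)
    return image_list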