from typing import List
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
from transformers import (
AutoImageProcessor,
Mask2FormerForUniversalSegmentation,
MaskFormerImageProcessor,
MaskFormerForInstanceSegmentation,
)


class MaskFormer:
"""MaskFormer semantic segmentation model.
Args:
model_size (str, optional):
Size of the MaskFormer model. Defaults to "large".
"""
    def __init__(self, model_size: str = "large") -> None:
assert model_size in [
"tiny",
"base",
"large",
], "Model size must be one of 'tiny', 'base', or 'large'"
self.processor = MaskFormerImageProcessor.from_pretrained(
f"facebook/maskformer-swin-{model_size}-ade"
)
self.model = MaskFormerForInstanceSegmentation.from_pretrained(
f"facebook/maskformer-swin-{model_size}-ade"
)

    def process(self, images: List[Image.Image]) -> List[torch.Tensor]:
inputs = self.processor(images=images, return_tensors="pt")
outputs = self.model(**inputs)
        # The model returns class_queries_logits of shape
        # (batch_size, num_queries, num_labels + 1) and masks_queries_logits
        # of shape (batch_size, num_queries, height, width); both live on
        # `outputs` and are combined by the processor's semantic
        # post-processing below (see the "Resources" section in the
        # MaskFormer docs for visualization notebooks).
        # PIL's .size is (width, height); target_sizes expects (height, width),
        # one entry per image in the batch.
        predicted_semantic_maps = self.processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1] for image in images]
        )
return predicted_semantic_maps
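

# Example usage for MaskFormer (a sketch, not part of the original file;
# "example.jpg" is a placeholder path):
#
#     segmenter = MaskFormer(model_size="base")
#     image = Image.open("example.jpg").convert("RGB")
#     semantic_map = segmenter.process([image])[0]  # (H, W) tensor of ADE20K class ids
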
class Mask2Former(MaskFormer):
"""Mask2Former semantic segmentation model.
Args:
model_size (str, optional):
Size of the Mask2Former model. Defaults to "large".
"""
    def __init__(self, model_size: str = "large") -> None:
assert model_size in [
"tiny",
"base",
"large",
], "Model size must be one of 'tiny', 'base', or 'large'"
self.processor = AutoImageProcessor.from_pretrained(
f"facebook/mask2former-swin-{model_size}-ade-semantic"
)
self.model = Mask2FormerForUniversalSegmentation.from_pretrained(
f"facebook/mask2former-swin-{model_size}-ade-semantic"
)
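

# Mask2Former inherits process() from MaskFormer unchanged; a usage sketch
# (placeholder path, not from the original file):
#
#     segmenter = Mask2Former(model_size="large")
#     image = Image.open("photo.jpg").convert("RGB")
#     semantic_map = segmenter.process([image])[0]
#     id2label = segmenter.model.config.id2label  # checkpoint's id -> class-name map
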
# class ADESegmentation:
# def __init__(self, model_name: str):
# self.processor = MODEL_DICT[model_name]["processor"].from_pretrained(
# MODEL_DICT[model_name]["name"]
# )
# self.model = MODEL_DICT[model_name]["model"].from_pretrained(
# MODEL_DICT[model_name]["name"]
# )
# def predict(self, image: Image.Image):
# inputs = processor(images=image, return_tensors="pt")
# outputs = model(**inputs)
# # model predicts class_queries_logits of shape `(batch_size, num_queries)`
# # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
# class_queries_logits = outputs.class_queries_logits
# masks_queries_logits = outputs.masks_queries_logits
# # you can pass them to processor for postprocessing
# # we refer to the demo notebooks for visualization (see "Resources" section in the MaskFormer docs)
# predicted_semantic_maps = processor.post_process_semantic_segmentation(
# outputs, target_sizes=[image.size[::-1]]
# )
# return predicted_semantic_maps
# def get_mask(self, predicted_semantic_maps, class_id: int):
# masks, labels, obj_names = get_masks_from_segmentation_map(
# predicted_semantic_maps[0]
# )
#         mask = masks[labels.index(class_id)]
# object_mask = np.logical_not(mask).astype(int)
# mask = torch.Tensor(mask).repeat(3, 1, 1)
# object_mask = torch.Tensor(object_mask).repeat(3, 1, 1)
# return mask, object_mask
# def get_PIL_mask(self, predicted_semantic_maps, class_id: int):
#         mask, object_mask = self.get_mask(predicted_semantic_maps, class_id=class_id)
# mask = transforms.ToPILImage()(mask)
# object_mask = transforms.ToPILImage()(object_mask)
# return mask, object_mask
# def get_PIL_segmentation_map(self, predicted_semantic_maps):
# return visualize_segmentation_map(predicted_semantic_maps[0])