|
import torch |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import gradio as gr |
|
from typing import Union |
|
|
|
class Preprocessor:
    """Callable that converts a PIL image into a normalized float tensor."""

    def __init__(self):
        """
        Build the preprocessing pipeline: PIL image -> tensor, then per-channel normalization.
        """
        # Per-channel mean/std normalization used by torchvision's pretrained models.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline = [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(pipeline)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """
        Run ``image`` through the ToTensor + Normalize pipeline.

        :param image: Image to preprocess.
        :return: The normalized image tensor.
        """
        normalized = self.transform(image)
        return normalized
|
|
|
class SegmentationModel:
    """Wrapper around torchvision's pretrained DeepLabV3-ResNet101 for inference."""

    def __init__(self):
        """
        Load the pretrained DeepLabV3 ResNet101 model in eval mode.

        The model is placed on the GPU when CUDA is available, otherwise on the CPU.
        The chosen device is stored in ``self.device`` so ``predict`` does not have
        to re-query CUDA on every call.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_pretrained()
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _load_pretrained() -> torch.nn.Module:
        """Build the model with its default pretrained weights, preferring the ``weights=`` API."""
        segmentation = models.segmentation
        weights_enum = getattr(segmentation, "DeepLabV3_ResNet101_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 has no weights enums and only understands the legacy flag.
            return segmentation.deeplabv3_resnet101(pretrained=True)
        # `pretrained=True` is deprecated since torchvision 0.13; DEFAULT selects the
        # same checkpoint the legacy flag maps to.
        return segmentation.deeplabv3_resnet101(weights=weights_enum.DEFAULT)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Run inference on a batch and return the output for its first image.

        :param input_batch: Batch of preprocessed images; moved to ``self.device`` here.
        :return: The model's ``'out'`` tensor indexed at ``[0]``, i.e. the result for
            the first image of the batch only (the batch dimension is dropped and any
            further images are discarded).
        """
        with torch.no_grad():
            output: torch.Tensor = self.model(input_batch.to(self.device))["out"][0]
        return output
|
|
|
class OutputColorizer:
    """Turn a 2-D map of per-pixel class indices into a paletted ("P" mode) PIL image."""

    def __init__(self, num_classes: int = 21):
        """
        Build one deterministic RGB colour per class index.

        :param num_classes: Number of palette entries to generate (default 21, the
            original hard-coded value). A "P" mode palette holds at most 256 entries.
        :raises ValueError: if ``num_classes`` is not in ``[1, 256]``.
        """
        if not 1 <= num_classes <= 256:
            raise ValueError(f"num_classes must be in [1, 256], got {num_classes}")
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        # Row i is i * palette; reducing mod 255 gives each index its own RGB triple.
        colors: torch.Tensor = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """
        Apply the palette to a segmentation index map.

        :param output: Per-pixel class indices; cast to uint8, so values above 255 wrap.
        :return: A "P" mode image whose palette is ``self.colors``.
        """
        indices = output.byte().cpu().numpy()
        # Passing `mode=` to Image.fromarray is deprecated (Pillow 11.3). A 2-D uint8
        # array already yields an "L" image, and putpalette() switches it to "P".
        colorized_output = Image.fromarray(indices)
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output
|
|
|
class Segmenter:
    """End-to-end pipeline: preprocess -> DeepLabV3 -> per-pixel argmax -> palette colouring."""

    def __init__(self):
        """
        Initialize the Segmenter with Preprocessor, SegmentationModel, and OutputColorizer.
        """
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """
        Perform the complete segmentation process on the input image.

        :param image: A PIL image, or an image tensor accepted by
            ``transforms.ToPILImage`` (float tensors are expected in [0, 1]).
            Any image mode is converted to RGB before preprocessing.
        :return: Colorized segmentation image.
        :raises ValueError: if ``image`` is None (e.g. the form was submitted without an upload).
        """
        if image is None:
            raise ValueError("No input image was provided.")
        if isinstance(image, torch.Tensor):
            # The annotation advertises tensor input, but tensors have no .convert().
            image = transforms.ToPILImage()(image)
        input_image: Image.Image = image.convert("RGB")
        input_batch: torch.Tensor = self.preprocessor(input_image).unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # predict() drops the batch axis, so dim 0 is the class dimension here.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
|
|
|
def _build_interface(segment_fn) -> gr.Interface:
    """Wrap ``segment_fn`` in a Gradio image-in / image-out interface (PIL on both sides)."""
    return gr.Interface(
        fn=segment_fn,
        inputs=gr.Image(type="pil"),
        outputs=gr.Image(type="pil"),
        title="Deeplabv3 Segmentation",
        description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.",
    )


# Module-level objects: the model is loaded at import time and the UI wraps its segment method.
segmenter = Segmenter()
interface = _build_interface(segmenter.segment)

if __name__ == "__main__":
    interface.launch()