# File size: 3,492 Bytes
# f252cc2
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Union
class Preprocessor:
    """
    Callable wrapper around a fixed tensor-conversion and normalization pipeline.
    """

    def __init__(self):
        """
        Build the transformation pipeline that is applied to every input image.
        """
        # Per-channel statistics handed to Normalize, which runs after ToTensor.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """
        Run the stored pipeline on a single image.
        :param image: Image to preprocess.
        :return: Output of the composed transforms.
        """
        tensor = self.transform(image)
        return tensor
class SegmentationModel:
    """
    Holds a pretrained DeepLabV3 ResNet101 network and runs gradient-free inference with it.
    """

    def __init__(self):
        """
        Load the pretrained network, switch it to eval mode and move it to
        the GPU when CUDA is available.
        """
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` - confirm against the installed version.
        network = models.segmentation.deeplabv3_resnet101(pretrained=True)
        network.eval()
        self.model = network
        if torch.cuda.is_available():
            self.model.to('cuda')

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch without tracking gradients.
        :param input_batch: Batch of preprocessed images.
        :return: The 'out' entry of the model result for the first batch item only.
        """
        use_gpu = torch.cuda.is_available()
        with torch.no_grad():
            # The batch follows the model onto the GPU when one is available.
            batch = input_batch.to('cuda') if use_gpu else input_batch
            result = self.model(batch)
            first_item: torch.Tensor = result['out'][0]
        return first_item
class OutputColorizer:
    """
    Turns a 2-D map of integer indices into a palette-mode PIL image.
    """

    def __init__(self):
        """
        Build the color table: 21 rows (one per index 0..20) of 3 uint8 values.
        """
        multipliers = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        indices = torch.as_tensor(list(range(21)))
        # Column vector times row vector broadcasts to a (21, 3) table.
        table: torch.Tensor = indices.unsqueeze(1) * multipliers
        wrapped = table % 255
        self.colors = wrapped.numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """
        Convert an index map into a palette image using the stored color table.
        :param output: Tensor of indices; it is cast to uint8 and moved to the CPU.
        :return: Image in mode 'P' with the color table attached as its palette.
        """
        index_map = output.byte().cpu().numpy()
        paletted = Image.fromarray(index_map, mode='P')
        # putpalette takes a flat sequence, so the (21, 3) table is flattened.
        paletted.putpalette(self.colors.ravel())
        return paletted
class Segmenter:
    """
    End-to-end pipeline: preprocess an image, run the model, colorize the result.
    """

    def __init__(self):
        """
        Initialize the Segmenter with Preprocessor, SegmentationModel, and OutputColorizer.
        """
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """
        Perform the complete segmentation process on the input image.
        :param image: Input image to be segmented. Either a PIL image or a
            tensor in a layout accepted by ``transforms.ToPILImage``.
        :return: Colorized segmentation image.
        :raises ValueError: If ``image`` is None (for example when no image was uploaded).
        """
        if image is None:
            # Fail with a clear message instead of an AttributeError on `.convert`.
            raise ValueError("No image provided for segmentation.")

        if isinstance(image, torch.Tensor):
            # The signature advertises tensor input, but tensors have no
            # `.convert`; go through PIL so both input kinds share one path.
            image = transforms.ToPILImage()(image)

        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        # The model takes a batch, so add a leading batch dimension of size 1.
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Pick the highest-scoring index along dim 0 at every position.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
# Build the pipeline once at import time; the interface below calls this
# single instance's bound method.
segmenter = Segmenter()

# Gradio UI: one PIL image in, one PIL image out, handled by Segmenter.segment.
interface = gr.Interface(
    fn=segmenter.segment,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Deeplabv3 Segmentation",
    description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.",
)

if __name__ == "__main__":
    # Launch the interface only when this file is run as a script.
    interface.launch()