# app.py -- created by etemkocaaslan ("Create app.py", commit f252cc2, verified)
# File size: 3.49 kB
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Union
class Preprocessor:
    """Callable that converts an image into a normalized tensor."""

    def __init__(self):
        """
        Assemble the preprocessing pipeline: tensor conversion followed by
        per-channel normalization with fixed mean/std constants.
        """
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        # Order matters: Normalize operates on the tensor produced by ToTensor.
        self.transform = transforms.Compose([to_tensor, normalize])

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """
        Run the preprocessing pipeline on a single image.
        :param image: Image to preprocess.
        :return: The transformed image as a tensor.
        """
        preprocessed = self.transform(image)
        return preprocessed
class SegmentationModel:
    """Wrapper that loads DeepLabV3 ResNet101 once and runs inference with it."""

    def __init__(self):
        """
        Initialize and load the DeepLabV3 ResNet101 model.

        The model is switched to evaluation mode and moved to the GPU when
        CUDA is available. The selected device is stored in ``self.device``
        so that ``predict`` places its inputs on the same device.
        """
        # Resolve the device once instead of re-querying CUDA on every call.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        segmentation = models.segmentation
        weights_enum = getattr(segmentation, 'DeepLabV3_ResNet101_Weights', None)
        if weights_enum is not None:
            # Newer torchvision deprecates `pretrained=` in favour of `weights=`.
            # COCO_WITH_VOC_LABELS_V1 is the checkpoint `pretrained=True` maps to.
            self.model = segmentation.deeplabv3_resnet101(
                weights=weights_enum.COCO_WITH_VOC_LABELS_V1
            )
        else:
            # Older torchvision releases only understand the legacy flag.
            self.model = segmentation.deeplabv3_resnet101(pretrained=True)

        self.model.eval()
        self.model.to(self.device)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Perform inference using the model on the input batch.
        :param input_batch: Batch of preprocessed images.
        :return: The 'out' entry of the model result for the first item of the
            batch only (remaining batch items are discarded).
        """
        # Gradients are never needed for inference; disabling them saves memory.
        with torch.no_grad():
            input_batch = input_batch.to(self.device)
            output: torch.Tensor = self.model(input_batch)['out'][0]
        return output
class OutputColorizer:
    """Maps a 2-D tensor of class indices to a palette-mode PIL image."""

    def __init__(self):
        """
        Build the lookup table of colors: one 3-value row for each of the
        21 class indices (0..20).
        """
        base = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        class_ids = torch.arange(21)
        # Multiply each class index by the base triple, then fold into byte range.
        scaled: torch.Tensor = class_ids.unsqueeze(1) * base
        self.colors = (scaled % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """
        Turn a segmentation result into a paletted image.
        :param output: Segmentation output tensor holding class indices.
        :return: Palette-mode image colored with the stored palette.
        """
        label_map = output.byte().cpu().numpy()
        paletted = Image.fromarray(label_map, mode='P')
        # The palette is supplied as one flat sequence of channel values.
        paletted.putpalette(self.colors.reshape(-1))
        return paletted
class Segmenter:
    """End-to-end pipeline: preprocess, run the model, colorize the result."""

    def __init__(self):
        """
        Initialize the Segmenter with Preprocessor, SegmentationModel, and OutputColorizer.
        """
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """
        Perform the complete segmentation process on the input image.
        :param image: Input image to be segmented. Either a PIL image or a
            tensor in a layout accepted by ``transforms.ToPILImage``.
        :return: Colorized segmentation image.
        :raises ValueError: If ``image`` is None (e.g. nothing was uploaded).
        """
        if image is None:
            # Fail with a clear message instead of an AttributeError on `.convert`.
            raise ValueError("No image provided; please upload an image to segment.")

        if isinstance(image, torch.Tensor):
            # The signature advertises tensor input, but the rest of the
            # pipeline works on PIL images, so convert up front.
            image = transforms.ToPILImage()(image)

        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        # The model expects a batch dimension in front.
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Pick the highest-scoring class along the first axis for every pixel.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
# Build the pipeline once at import time; the interface below reuses this instance.
segmenter = Segmenter()

# Gradio UI: one PIL image in, one PIL image out.
interface = gr.Interface(
    title="Deeplabv3 Segmentation",
    description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.",
    fn=segmenter.segment,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

if __name__ == "__main__":
    # Launch the app only when this file is executed as a script.
    interface.launch()