"""Gradio demo that runs DeepLabV3-ResNet101 semantic segmentation on an image."""

from typing import Union

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms


class Preprocessor:
    """Turn a PIL image into a normalized tensor suitable for the model."""

    def __init__(self) -> None:
        """
        Initialize the preprocessing transformations.

        The pipeline converts the image to a tensor and then normalizes each
        channel with the fixed mean / std values given below.
        """
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """
        Apply preprocessing to the input image.

        :param image: Input image to be preprocessed.
        :return: Preprocessed image as a tensor.
        """
        return self.transform(image)


class SegmentationModel:
    """Wrapper around a pretrained DeepLabV3 ResNet101 model in eval mode."""

    def __init__(self) -> None:
        """
        Initialize and load the DeepLabV3 ResNet101 model.

        The target device is resolved once here (CUDA when available,
        otherwise CPU) and reused by :meth:`predict`.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = self._load_pretrained()
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _load_pretrained() -> torch.nn.Module:
        """
        Build the pretrained network without tripping deprecation warnings.

        Newer torchvision releases expose a weights enum and warn on the
        legacy ``pretrained=True`` flag; older releases only know the flag.

        :return: The DeepLabV3 ResNet101 model with pretrained weights.
        """
        weights_enum = getattr(
            models.segmentation, "DeepLabV3_ResNet101_Weights", None
        )
        if weights_enum is None:
            # Older torchvision: only the legacy boolean flag is available.
            return models.segmentation.deeplabv3_resnet101(pretrained=True)
        # Pin the checkpoint the legacy flag maps to, rather than DEFAULT,
        # so a future change of the default weights cannot alter results.
        return models.segmentation.deeplabv3_resnet101(
            weights=weights_enum.COCO_WITH_VOC_LABELS_V1
        )

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """
        Perform inference using the model on the input batch.

        Only the result for the first element of the batch is returned.

        :param input_batch: Batch of preprocessed images.
        :return: The ``'out'`` entry of the model output for the first batch
            element, left on the device the model runs on.
        """
        with torch.no_grad():
            input_batch = input_batch.to(self.device)
            output: torch.Tensor = self.model(input_batch)['out'][0]
        return output


class OutputColorizer:
    """Map a 2-D tensor of class indices to a palettized PIL image."""

    def __init__(self, num_classes: int = 21) -> None:
        """
        Initialize the color palette for segmentations.

        :param num_classes: Number of palette entries to generate. Must be
            between 1 and 256 because a ``'P'`` mode image holds at most
            256 palette colors. Defaults to 21.
        :raises ValueError: If ``num_classes`` is outside ``1..256``.
        """
        if not 1 <= num_classes <= 256:
            raise ValueError(
                f"num_classes must be between 1 and 256, got {num_classes}"
            )
        # Multiplying each class index by three large constants and taking
        # the result modulo 255 yields one (R, G, B) row per class.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors: torch.Tensor = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """
        Apply colorization to the segmentation output.

        :param output: Segmentation output tensor of class indices. It is
            cast to bytes, so values are expected to fit in ``0..255``.
        :return: Colorized segmentation image in palette (``'P'``) mode.
        """
        class_map = output.byte().cpu().numpy()
        colorized_output = Image.fromarray(class_map, mode='P')
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output


class Segmenter:
    """End-to-end pipeline: preprocess, run the model, colorize the result."""

    def __init__(self) -> None:
        """
        Initialize the Segmenter with Preprocessor, SegmentationModel, and
        OutputColorizer.
        """
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """
        Perform the complete segmentation process on the input image.

        :param image: Input image to be segmented. A tensor is first
            converted to a PIL image with ``transforms.ToPILImage``
            (presumably a C x H x W image tensor; float tensors are assumed
            to be in ``[0, 1]`` -- confirm against callers).
        :return: Colorized segmentation image.
        :raises ValueError: If ``image`` is ``None``, e.g. when the UI
            submits without an uploaded image.
        """
        if image is None:
            raise ValueError("No input image was provided.")
        if isinstance(image, torch.Tensor):
            # The signature advertises tensor support, but the pipeline
            # below needs a PIL image for ``convert``.
            image = transforms.ToPILImage()(image)

        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        # The model expects a batch dimension in front.
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Collapse dimension 0 of the model output to the index of its
        # largest value (the predicted class index per position).
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)


# Built at import time so the names ``segmenter`` and ``interface`` stay
# importable from this module.
segmenter = Segmenter()

interface = gr.Interface(
    fn=segmenter.segment,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Deeplabv3 Segmentation",
    description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101."
)

if __name__ == "__main__":
    interface.launch()