import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr


class Preprocessor:
    """Converts a PIL image into a normalized tensor using ImageNet statistics."""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, image: Image.Image) -> torch.Tensor:
        return self.transform(image)


class SegmentationModel:
    """Wraps a pretrained DeepLabv3 ResNet101 model for inference."""

    def __init__(self):
        # The `pretrained=True` flag is deprecated; use the weights enum instead.
        self.model = models.segmentation.deeplabv3_resnet101(
            weights=models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
        )
        self.model.eval()
        if torch.cuda.is_available():
            self.model.to('cuda')

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            if torch.cuda.is_available():
                input_batch = input_batch.to('cuda')
            output: torch.Tensor = self.model(input_batch)['out'][0]
        return output


class OutputColorizer:
    """Maps each of the 21 Pascal VOC class indices to a distinct palette color."""

    def __init__(self):
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output


class Segmenter:
    """End-to-end pipeline: preprocess, run the model, and colorize the prediction."""

    def __init__(self):
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Image.Image) -> Image.Image:
        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)


class GradioApp:
    """Builds and launches the Gradio Blocks interface around a Segmenter."""

    def __init__(self, segmenter: Segmenter):
        self.segmenter = segmenter

    def launch(self):
        with gr.Blocks() as demo:
            gr.Markdown("# Deeplabv3 Segmentation")
            gr.Markdown(
                "Upload an image to perform semantic segmentation using "
                "Deeplabv3 ResNet101."
            )
            gr.Markdown("""
                ### Model Information
                **DeepLabv3 with ResNet101** is a convolutional neural network model
                designed for semantic image segmentation. It utilizes atrous convolution
                to capture multi-scale context by using different atrous rates.
            """)
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil', label="Input Image",
                                           show_label=False)
                with gr.Column():
                    image_output = gr.Image(type='pil', label="Segmented Output",
                                            show_label=False)
            button = gr.Button("Segment")
            button.click(fn=self.segmenter.segment,
                         inputs=image_input,
                         outputs=image_output)
            gr.Markdown("### Example Images")
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"],
                ],
                inputs=image_input,
                outputs=image_output,
                label="Click an example to use it",
            )
        demo.launch()


if __name__ == "__main__":
    segmenter = Segmenter()
    app = GradioApp(segmenter)
    app.launch()