"""Gradio demo: semantic segmentation with torchvision's DeepLabV3-ResNet101."""

from typing import Union

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms


class Preprocessor:
    """Convert a PIL image into a normalized float tensor for the model.

    The mean/std values are the standard ImageNet normalization constants
    used by torchvision's pretrained models.
    """

    def __init__(self) -> None:
        self.transform = transforms.Compose([
            # PIL image -> (C, H, W) float tensor scaled to [0, 1].
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Return *image* as a normalized ``(C, H, W)`` tensor."""
        return self.transform(image)


def _load_pretrained_deeplabv3() -> torch.nn.Module:
    """Load DeepLabV3-ResNet101 with its default pretrained weights.

    torchvision >= 0.13 replaced ``pretrained=True`` with the ``weights``
    argument and warns on the old form. Older releases do not accept
    ``weights`` at all and raise TypeError, so fall back to the legacy
    keyword there.
    """
    try:
        return models.segmentation.deeplabv3_resnet101(weights="DEFAULT")
    except TypeError:
        return models.segmentation.deeplabv3_resnet101(pretrained=True)


class SegmentationModel:
    """Pretrained DeepLabV3 in eval mode, placed on the GPU when available."""

    def __init__(self) -> None:
        # Resolve the device once instead of re-querying CUDA on every call.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = _load_pretrained_deeplabv3()
        self.model.eval()
        self.model.to(self.device)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Run inference and return the raw scores for the first batch item.

        Args:
            input_batch: Normalized images with a leading batch dimension,
                e.g. the ``(1, C, H, W)`` tensor built by ``Segmenter.segment``.

        Returns:
            ``model(input_batch)["out"][0]``. ``Segmenter.segment`` takes
            ``argmax(0)`` of it, i.e. it treats dim 0 as the class dimension.
        """
        with torch.no_grad():
            output: torch.Tensor = self.model(
                input_batch.to(self.device)
            )["out"][0]
        return output


class OutputColorizer:
    """Turn a 2-D tensor of class indices into a palette ("P" mode) image."""

    def __init__(self, num_classes: int = 21) -> None:
        # Derive one RGB triple per class index from fixed per-channel
        # multipliers, wrapped into byte range. Consecutive indices get
        # visually distinct colors.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """Return the class-index map *output* as a palette image.

        Indices are cast to uint8, so they are assumed to be below 256.
        """
        indices = output.byte().cpu().numpy()
        # A 2-D uint8 array becomes an "L" image, and putpalette() switches
        # it to "P". Passing mode="P" to fromarray() is deprecated in recent
        # Pillow releases.
        colorized_output = Image.fromarray(indices)
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output


class Segmenter:
    """End-to-end pipeline: preprocess, run the model, colorize the mask."""

    def __init__(self) -> None:
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """Segment *image* and return a colorized per-pixel class mask.

        Args:
            image: A PIL image, or a ``(C, H, W)`` / ``(H, W)`` tensor. A tensor
                is converted with ``to_pil_image``, which expects float
                tensors to be in ``[0, 1]``.

        Raises:
            gr.Error: If *image* is None, e.g. when the Gradio form is
                submitted without an upload.
        """
        if image is None:
            raise gr.Error("Please upload an image.")
        if isinstance(image, torch.Tensor):
            # The annotation admits tensors, but tensors have no .convert().
            image = transforms.functional.to_pil_image(image)
        input_image: Image.Image = image.convert("RGB")
        # Add the leading batch dimension the model expects.
        input_batch: torch.Tensor = self.preprocessor(input_image).unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Pick the highest-scoring class for every pixel.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)


# NOTE: constructing the Segmenter loads (and, on first run, downloads) the
# pretrained weights at import time.
segmenter = Segmenter()

interface = gr.Interface(
    fn=segmenter.segment,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Deeplabv3 Segmentation",
    description=(
        "Upload an image to perform semantic segmentation using "
        "Deeplabv3 ResNet101."
    ),
)

if __name__ == "__main__":
    interface.launch()