File size: 2,391 Bytes
f252cc2 cea72b9 f252cc2 cea72b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Union
class Preprocessor:
    """Callable that turns a PIL image into a normalized tensor.

    The pipeline is ``ToTensor`` followed by a per-channel ``Normalize``.
    """

    def __init__(self):
        # Per-channel statistics for a 3-channel image (these are the values
        # torchvision documents for its ImageNet-pretrained models).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Apply the composed transform to *image* and return the tensor."""
        tensor = self.transform(image)
        return tensor
class SegmentationModel:
    """Wrapper around a pretrained DeepLabV3-ResNet101 network in eval mode.

    The network is placed on the GPU whenever CUDA is available.
    """

    def __init__(self):
        # NOTE(review): newer torchvision releases prefer a ``weights=``
        # argument over ``pretrained=True`` -- confirm against the pinned
        # torchvision version before changing.
        network = models.segmentation.deeplabv3_resnet101(pretrained=True)
        network.eval()  # inference only
        self.model = network
        if torch.cuda.is_available():
            self.model.to('cuda')

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Run the network on *input_batch* without tracking gradients.

        Only the ``'out'`` entry for the first element of the batch is
        returned; it stays on whatever device the model ran on.
        """
        use_gpu = torch.cuda.is_available()
        with torch.no_grad():
            batch = input_batch.to('cuda') if use_gpu else input_batch
            result = self.model(batch)
            first_item: torch.Tensor = result['out'][0]
        return first_item
class OutputColorizer:
    """Map a tensor of class indices to a palettized PIL image."""

    def __init__(self):
        # One RGB triple per class index 0..20, derived by scaling three
        # fixed multipliers by the index and reducing modulo 255.
        multipliers = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        class_ids = torch.arange(21).unsqueeze(1)
        table = (class_ids * multipliers) % 255
        self.colors = table.numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """Return a mode-``P`` image whose pixel values are the class indices.

        *output* is cast to uint8 and moved to the CPU before the image is
        built; the colour table is attached as the image palette.
        """
        class_map = output.byte().cpu().numpy()
        paletted = Image.fromarray(class_map, mode='P')
        paletted.putpalette(self.colors.ravel())
        return paletted
class Segmenter:
    """End-to-end pipeline: preprocess, run the network, colorize the result."""

    def __init__(self):
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """Segment *image* and return the colorized class map.

        Args:
            image: A PIL image, or a tensor that ``transforms.ToPILImage``
                can convert (assumes a single un-normalized, channel-first
                image rather than a batch -- TODO confirm with callers).

        Returns:
            A palettized PIL image with one palette index per pixel.

        Raises:
            ValueError: If *image* is ``None`` (e.g. the UI submitted the
                form without an uploaded image).
        """
        if image is None:
            raise ValueError(
                "No image provided: upload an image before running segmentation."
            )

        # The signature advertises tensor input, but the pipeline below relies
        # on PIL's ``convert``; bridge tensors through PIL so both input kinds
        # share one code path.
        if isinstance(image, torch.Tensor):
            image = transforms.ToPILImage()(image)

        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        # The model expects a batch dimension; add one for the single image.
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Highest-scoring class along dim 0 for every pixel.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
# Build the pipeline once at import time so every request reuses the same
# loaded model instead of re-creating it per call.
segmenter = Segmenter()

# Gradio UI: one PIL image in, one PIL image out, served by Segmenter.segment.
interface = gr.Interface(
    fn=segmenter.segment,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Deeplabv3 Segmentation",
    description="Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|