|
import torch |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import gradio as gr |
|
from typing import Union |
|
|
|
class Preprocessor:
    """Convert a PIL image into a normalised float tensor for the segmentation model."""

    def __init__(self):
        # Standard ImageNet per-channel mean / std used by torchvision's pretrained backbones.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps = [
            # PIL image (H x W x C, values 0-255) -> float tensor (C x H x W, values 0-1).
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Run *image* through the transform pipeline and return the resulting tensor."""
        normalised = self.transform(image)
        return normalised
|
|
|
class SegmentationModel:
    """Pretrained torchvision DeepLabV3-ResNet101 wrapped for gradient-free inference."""

    def __init__(self):
        # Resolve the target device once instead of re-querying CUDA on every call.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            # torchvision >= 0.13 weights API; "DEFAULT" selects the same checkpoint
            # that the deprecated ``pretrained=True`` flag loads.
            self.model = models.segmentation.deeplabv3_resnet101(weights="DEFAULT")
        except TypeError:
            # Older torchvision releases do not accept the ``weights`` keyword.
            self.model = models.segmentation.deeplabv3_resnet101(pretrained=True)
        self.model.eval()
        self.model.to(self.device)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Run the model on *input_batch* and return the raw output of the first item.

        Args:
            input_batch: Batched image tensor; it is moved to the model's device here.

        Returns:
            ``model(input_batch)["out"][0]``, i.e. the ``"out"`` head for batch item 0
            (it stays on the model's device).
        """
        with torch.no_grad():
            output = self.model(input_batch.to(self.device))["out"]
        return output[0]
|
|
|
class OutputColorizer:
    """Render a per-pixel class-index map as a palette ("P" mode) PIL image."""

    def __init__(self, num_classes: int = 21):
        """Build one deterministic RGB colour per class.

        Args:
            num_classes: Number of palette entries to generate (defaults to 21,
                the count this file has always used).
        """
        # Multiply each class index by three large constants and wrap into the
        # byte range, giving a fixed, well-spread colour per class.
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """Convert a 2-D tensor of class indices into a palette image.

        Indices are cast to uint8, so values must lie in [0, 255] (a palette image
        cannot index more colours anyway).
        """
        indices = output.byte().cpu().numpy()
        # A 2-D uint8 array becomes an "L" image and putpalette() switches it to "P";
        # this avoids fromarray's ``mode`` argument, which Pillow 11.3 deprecated.
        colorized_output = Image.fromarray(indices)
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output
|
|
|
class Segmenter:
    """End-to-end pipeline: preprocess an image, run DeepLabV3, colourise the class map."""

    def __init__(self):
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(self, image: Union[Image.Image, torch.Tensor]) -> Image.Image:
        """Segment *image* and return a palette image of per-pixel class labels.

        Args:
            image: A PIL image, or a tensor accepted by torchvision's
                ``to_pil_image`` (C x H x W or H x W; float tensors are assumed
                to hold values in [0, 1] — confirm for your callers).

        Returns:
            A palette image with one colour per predicted class.

        Raises:
            ValueError: If *image* is None (e.g. "Segment" clicked before upload).
        """
        if image is None:
            raise ValueError("No image provided; please upload an image first.")
        if isinstance(image, torch.Tensor):
            # The signature accepts tensors, but only PIL images have ``.convert``.
            image = transforms.functional.to_pil_image(image)

        input_image: Image.Image = image.convert("RGB")
        input_tensor: torch.Tensor = self.preprocessor(input_image)
        input_batch: torch.Tensor = input_tensor.unsqueeze(0)  # add batch dimension
        output: torch.Tensor = self.model.predict(input_batch)
        # predict() returns a single batch item, so dim 0 is the class dimension.
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
|
|
|
class GradioApp:
    """Gradio Blocks front end that runs a Segmenter on uploaded images."""

    def __init__(self, segmenter: Segmenter):
        self.segmenter = segmenter

    def _build(self) -> gr.Blocks:
        """Assemble the layout and event wiring without starting a server."""
        with gr.Blocks() as demo:
            gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
            gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
            gr.Markdown("""
            ### Model Information
            **DeepLabv3 with ResNet101** is a convolutional neural network model designed for semantic image segmentation.
            It utilizes atrous convolution to capture multi-scale context by using different atrous rates.
            """)

            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil', label="Input Image", show_label=False)
                with gr.Column():
                    image_output = gr.Image(type='pil', label="Segmented Output", show_label=False)

            button = gr.Button("Segment")
            button.click(fn=self.segmenter.segment, inputs=image_input, outputs=image_output)

            gr.Markdown("### Example Images")
            # No ``fn`` is given, so clicking an example only fills the input image.
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"],
                ],
                inputs=image_input,
                outputs=image_output,
                label="Click an example to use it",
            )
        return demo

    def launch(self, **launch_kwargs):
        """Build the UI and start the Gradio server.

        Args:
            **launch_kwargs: Forwarded unchanged to ``demo.launch`` (for example
                ``share=True`` or ``server_port=7860``).
        """
        # Launch only after the ``with gr.Blocks()`` context has closed, matching
        # the pattern in Gradio's documentation.
        demo = self._build()
        demo.launch(**launch_kwargs)
|
|
|
if __name__ == "__main__":
    # Load the model (and any pretrained weights) before the UI starts serving.
    segmenter = Segmenter()
    app = GradioApp(segmenter=segmenter)
    app.launch()