File size: 3,970 Bytes
f252cc2 cea72b9 f252cc2 9d464be f252cc2 9d464be f252cc2 cea72b9 9d464be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Union
class Preprocessor:
    """Turn a PIL image into a normalized float tensor for the segmentation model.

    The pipeline is ``ToTensor`` followed by ``Normalize`` with the per-channel
    mean/std constants below (the standard ImageNet channel statistics).
    """

    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self):
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Run *image* through the transform pipeline and return the tensor."""
        tensor = self.transform(image)
        return tensor
class SegmentationModel:
    """Pretrained torchvision DeepLabV3-ResNet101 wrapped for inference.

    The model is switched to eval mode and placed on CUDA when available,
    otherwise on the CPU; ``predict`` moves its input to the same device.
    """

    def __init__(self):
        # Decide the device once so __init__ and predict() can never disagree.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_pretrained()
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _load_pretrained():
        """Build DeepLabV3-ResNet101 with its default pretrained weights."""
        try:
            # torchvision >= 0.13 uses the `weights` API; `pretrained=` is deprecated there.
            return models.segmentation.deeplabv3_resnet101(weights="DEFAULT")
        except TypeError:
            # Older torchvision does not accept `weights`; use the legacy flag instead.
            return models.segmentation.deeplabv3_resnet101(pretrained=True)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Run the model without autograd and return the first item's ``'out'`` tensor.

        Args:
            input_batch: batched input tensor; it is moved to ``self.device``
                before the forward pass.

        Returns:
            ``model(input_batch)['out'][0]`` — only the first element of the batch
            is returned, so callers are expected to pass a batch of one.
        """
        with torch.no_grad():
            input_batch = input_batch.to(self.device)
            output: torch.Tensor = self.model(input_batch)["out"][0]
        return output
class OutputColorizer:
    """Map a 2-D tensor of class indices to a paletted ("P" mode) PIL image."""

    def __init__(self, num_classes: int = 21):
        """Precompute one RGB color per class index.

        Args:
            num_classes: number of palette entries to generate. Defaults to 21,
                the count previously hard-coded here.
        """
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        # Scale each class index by a fixed vector and wrap into 0..254, giving a
        # deterministic, well-spread color per class.
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> Image.Image:
        """Return *output* as a "P" image whose palette holds ``self.colors``.

        Values are cast to uint8, so class indices above 255 would wrap.
        """
        indices = output.byte().cpu().numpy()
        # A 2-D uint8 array becomes an "L" image; putpalette() converts it to "P".
        # This avoids passing `mode` to fromarray, which recent Pillow deprecates.
        colorized_output = Image.fromarray(indices)
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output
class Segmenter:
    """End-to-end pipeline: preprocess an image, run DeepLabV3, colorize the mask."""

    def __init__(self):
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(
        self, image: Union[Image.Image, torch.Tensor, None]
    ) -> Union[Image.Image, None]:
        """Return a color-coded segmentation mask for *image*.

        Args:
            image: a PIL image, or an image tensor in a layout accepted by
                ``transforms.ToPILImage`` (CHW or HW). ``None`` — what the Gradio
                input presumably delivers when nothing is uploaded — yields ``None``.

        Returns:
            The colorized class map, or ``None`` when no image was given.
        """
        if image is None:
            # Nothing to segment: clear the output instead of raising AttributeError.
            return None
        if isinstance(image, torch.Tensor):
            # Tensors have no .convert(); turn them into a PIL image first so the
            # advertised tensor input actually works.
            image = transforms.ToPILImage()(image)
        input_image: Image.Image = image.convert("RGB")
        input_batch: torch.Tensor = self.preprocessor(input_image).unsqueeze(0)
        # predict() returns the model output for the single image; presumably
        # (classes, H, W), so argmax over dim 0 picks the class per pixel.
        output: torch.Tensor = self.model.predict(input_batch)
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
class GradioApp:
    """Gradio Blocks UI that runs a Segmenter on an uploaded or example image."""

    # Static page text, kept as data so launch() reads as pure layout.
    _TITLE = "<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>"
    _SUBTITLE = (
        "<p style='text-align: center;'>Upload an image to perform semantic "
        "segmentation using Deeplabv3 ResNet101.</p>"
    )
    _MODEL_INFO = (
        "\n"
        "### Model Information\n"
        "**DeepLabv3 with ResNet101** is a convolutional neural network model "
        "designed for semantic image segmentation.\n"
        "It utilizes atrous convolution to capture multi-scale context by using "
        "different atrous rates.\n"
    )
    _EXAMPLE_URLS = (
        "https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024",
        "https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024",
        "https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024",
    )

    def __init__(self, segmenter: Segmenter):
        self.segmenter = segmenter

    def launch(self):
        """Build the page, wire the Segment button to the segmenter, and launch it."""
        with gr.Blocks() as demo:
            for text in (self._TITLE, self._SUBTITLE, self._MODEL_INFO):
                gr.Markdown(text)

            with gr.Row():
                with gr.Column():
                    source = gr.Image(type='pil', label="Input Image", show_label=False)
                with gr.Column():
                    result = gr.Image(type='pil', label="Segmented Output", show_label=False)

            segment_btn = gr.Button("Segment")
            segment_btn.click(fn=self.segmenter.segment, inputs=source, outputs=result)

            gr.Markdown("### Example Images")
            gr.Examples(
                examples=[[url] for url in self._EXAMPLE_URLS],
                inputs=source,
                outputs=result,
                label="Click an example to use it",
            )
        demo.launch()
if __name__ == "__main__":
    # Build the (heavy) model pipeline once, then hand it to the UI.
    segmenter = Segmenter()
    app = GradioApp(segmenter)
    app.launch()