# app.py — copied from a repository web view.
# Author: etemkocaaslan · last commit 9d464be ("Update app.py") · file size 3.97 kB.
import torch
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from typing import Union
class Preprocessor:
    """Convert a PIL image into the normalized tensor the segmentation network expects.

    The pipeline is ``ToTensor`` (PIL image -> CHW float tensor scaled to [0, 1])
    followed by per-channel ``Normalize`` using the standard ImageNet mean/std
    statistics that torchvision's pretrained weights were trained with.
    """

    def __init__(self):
        channel_means = [0.485, 0.456, 0.406]
        channel_stds = [0.229, 0.224, 0.225]
        pipeline_steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_means, std=channel_stds),
        ]
        self.transform = transforms.Compose(pipeline_steps)

    def __call__(self, image: Image.Image) -> torch.Tensor:
        """Run *image* through the preprocessing pipeline and return the tensor."""
        tensor = self.transform(image)
        return tensor
class SegmentationModel:
    """Pretrained DeepLabV3-ResNet101 wrapped for inference."""

    def __init__(self, device: Union[str, torch.device, None] = None):
        """Load the pretrained network in eval mode and move it to *device*.

        Args:
            device: Where to run the model. ``None`` (the default) picks
                ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise,
                which matches the previous hard-coded behavior.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Resolve the device once so __init__ and predict cannot disagree.
        self.device = torch.device(device)
        self.model = self._load_pretrained()
        self.model.eval()
        self.model.to(self.device)

    @staticmethod
    def _load_pretrained() -> torch.nn.Module:
        """Build deeplabv3_resnet101 with its default pretrained weights.

        torchvision >= 0.13 deprecates ``pretrained=True`` in favor of the
        ``weights=`` enum API; older releases only understand ``pretrained``.
        """
        weights_enum = getattr(models.segmentation, "DeepLabV3_ResNet101_Weights", None)
        if weights_enum is None:
            # Older torchvision without the multi-weight API.
            return models.segmentation.deeplabv3_resnet101(pretrained=True)
        return models.segmentation.deeplabv3_resnet101(weights=weights_enum.DEFAULT)

    def predict(self, input_batch: torch.Tensor) -> torch.Tensor:
        """Run the network on *input_batch* and return the result for its first item.

        Args:
            input_batch: Preprocessed images stacked along dim 0 (``Segmenter``
                passes a single image unsqueezed into a batch of one).

        Returns:
            ``self.model(input_batch)['out'][0]``, the per-class scores for the
            first image. It is left on ``self.device``.
        """
        with torch.no_grad():
            output = self.model(input_batch.to(self.device))["out"]
        return output[0]
class OutputColorizer:
    """Render a 2-D tensor of class indices as a palettized ("P" mode) PIL image."""

    def __init__(self, num_classes: int = 21):
        """Precompute one RGB palette entry per class.

        Args:
            num_classes: Number of class indices to color. Defaults to 21, the
                value that was previously hard-coded.

        Raises:
            ValueError: if *num_classes* is outside [1, 256]. A "P" image stores
                one byte per pixel, and ``colorize`` casts indices with
                ``.byte()``, so larger indices would silently wrap around.
        """
        if not 1 <= num_classes <= 256:
            raise ValueError(f"num_classes must be in [1, 256], got {num_classes}")
        palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        # Row i is i * palette folded into byte range: a cheap, deterministic
        # per-class color. The first channel equals i for i < 255.
        colors = torch.arange(num_classes)[:, None] * palette
        self.colors = (colors % 255).numpy().astype("uint8")

    def colorize(self, output: torch.Tensor) -> "Image.Image":
        """Convert a tensor of class indices into a palettized PIL image.

        Args:
            output: Integer class indices, one per pixel (as produced by
                ``argmax`` over the class dimension). It may live on any device.

        Returns:
            A "P" mode image whose palette is ``self.colors``.
        """
        colorized_output = Image.fromarray(output.byte().cpu().numpy(), mode='P')
        colorized_output.putpalette(self.colors.ravel())
        return colorized_output
class Segmenter:
    """End-to-end pipeline: preprocess an image, run DeepLabV3, and colorize the class map."""

    def __init__(self):
        self.preprocessor = Preprocessor()
        self.model = SegmentationModel()
        self.colorizer = OutputColorizer()

    def segment(
        self, image: Union[Image.Image, torch.Tensor, None]
    ) -> Union[Image.Image, None]:
        """Segment *image* and return a palettized PIL image of predicted classes.

        Args:
            image: A PIL image (any mode; it is converted to RGB), a CHW image
                tensor, or ``None``.

        Returns:
            The colorized per-pixel class map, or ``None`` when *image* is ``None``.
        """
        if image is None:
            # Gradio's Image input yields None when "Segment" is clicked before
            # anything is uploaded. Clear the output instead of raising
            # AttributeError on None.convert.
            return None
        if isinstance(image, torch.Tensor):
            # Tensors have no .convert(). Route them through PIL so they receive
            # exactly the same preprocessing as uploaded images.
            image = transforms.ToPILImage()(image)
        input_image: Image.Image = image.convert("RGB")
        input_batch: torch.Tensor = self.preprocessor(input_image).unsqueeze(0)
        output: torch.Tensor = self.model.predict(input_batch)
        # Pick the highest-scoring class per pixel along dim 0 (the class axis).
        output_predictions: torch.Tensor = output.argmax(0)
        return self.colorizer.colorize(output_predictions)
class GradioApp:
    """Gradio Blocks user interface around a ``Segmenter``."""

    def __init__(self, segmenter: Segmenter):
        self.segmenter = segmenter

    def _build_demo(self) -> gr.Blocks:
        """Assemble the page layout and wire the Segment button to the segmenter."""
        with gr.Blocks() as demo:
            gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>Deeplabv3 Segmentation</h1>")
            gr.Markdown("<p style='text-align: center;'>Upload an image to perform semantic segmentation using Deeplabv3 ResNet101.</p>")
            gr.Markdown(
                "\n"
                "### Model Information\n"
                "**DeepLabv3 with ResNet101** is a convolutional neural network model designed for semantic image segmentation.\n"
                "It utilizes atrous convolution to capture multi-scale context by using different atrous rates.\n"
            )
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type='pil', label="Input Image", show_label=False)
                with gr.Column():
                    image_output = gr.Image(type='pil', label="Segmented Output", show_label=False)
            button = gr.Button("Segment")
            button.click(fn=self.segmenter.segment, inputs=image_input, outputs=image_output)
            gr.Markdown("### Example Images")
            gr.Examples(
                examples=[
                    ["https://www.timeforkids.com/wp-content/uploads/2024/01/Snapshot_20240126.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2023/09/G3G5_230915_puffins_on_the_rise.jpg?w=1024"],
                    ["https://www.timeforkids.com/wp-content/uploads/2024/03/G3G5_240412_bug_eyed.jpg?w=1024"]
                ],
                inputs=image_input,
                outputs=image_output,
                label="Click an example to use it"
            )
        return demo

    def launch(self, **launch_kwargs):
        """Build the UI and start the Gradio server.

        Keyword arguments are forwarded unchanged to ``gr.Blocks.launch``, for
        example ``share=True`` or ``server_name="0.0.0.0"``. Calling it with no
        arguments behaves exactly like the original ``demo.launch()``.
        """
        demo = self._build_demo()
        demo.launch(**launch_kwargs)
if __name__ == "__main__":
    # Script entry point: load the model once, then serve the UI around it.
    app = GradioApp(segmenter=Segmenter())
    app.launch()