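# Header captured from the hosting page (kept as comments so this file parses as Python):
# Author: Tahir5
# Commit: "Create app.py" (c4b62dd, verified)
# Original size: 2.15 kB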
import streamlit as st
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
import numpy as np
import matplotlib.pyplot as plt
# Hugging Face checkpoint: Mask2Former fine-tuned on ADE20k semantic segmentation.
MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"

st.title("Mask2Former Semantic Segmentation")
st.write("Upload an image to perform semantic segmentation using Mask2Former.")


@st.cache_resource
def _load_model():
    """Load the Mask2Former image processor and model once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the (large) checkpoint would be reloaded on every click.

    Returns:
        A ``(processor, model)`` tuple for ``MODEL_ID``.
    """
    loaded_processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    loaded_model = Mask2FormerForUniversalSegmentation.from_pretrained(MODEL_ID)
    return loaded_processor, loaded_model


# Module-level names kept so the rest of the script can use them unchanged.
processor, model = _load_model()
def segment_image(image: Image.Image):
    """Run Mask2Former semantic segmentation on a single PIL image.

    Args:
        image: Input image. Non-RGB modes (e.g. RGBA, grayscale, palette) are
            converted to RGB first, because the processor normalizes three
            colour channels.

    Returns:
        The first (and only) entry of
        ``processor.post_process_semantic_segmentation``: the predicted
        per-pixel class map, resized to the input's (height, width).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports size as (width, height); post-processing expects (height, width).
    target_size = image.size[::-1]
    return processor.post_process_semantic_segmentation(
        outputs, target_sizes=[target_size]
    )[0]
def visualize_segmentation(image: Image.Image, segmentation_map):
    """Render the original image and its segmentation map side by side in Streamlit.

    Args:
        image: The original image shown in the left panel.
        segmentation_map: Label map shown in the right panel with the "jet"
            colormap (callers pass ``segment_image(...).numpy()``; presumably
            a 2-D array of class ids — confirm if used elsewhere).
    """
    # Use an explicit Figure: passing the global ``plt`` module to st.pyplot is
    # deprecated, and an unclosed figure accumulates across Streamlit reruns.
    fig, (ax_original, ax_segmented) = plt.subplots(1, 2, figsize=(10, 5))
    try:
        ax_original.imshow(image)
        ax_original.axis("off")
        ax_original.set_title("Original Image")

        ax_segmented.imshow(segmentation_map, cmap="jet", alpha=0.7)
        ax_segmented.axis("off")
        ax_segmented.set_title("Segmented Image")

        st.pyplot(fig)
    finally:
        plt.close(fig)
# File uploader for user to upload an image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    try:
        # convert() forces a full decode, so corrupt files fail here, not later.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        st.image(image, caption="Uploaded Image", use_column_width=True)
        if st.button("Segment Image"):
            # The spinner disappears once segmentation finishes.
            with st.spinner("Processing the image..."):
                segmentation_map = segment_image(image)
            visualize_segmentation(image, segmentation_map.numpy())
# Option to test with a sample image
if st.button("Use Sample Image"):
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    try:
        # Timeout so a stalled server cannot hang the app; the context manager
        # releases the connection once the image has been decoded.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # Let urllib3 undo any transfer Content-Encoding on the raw stream.
            response.raw.decode_content = True
            # convert() reads the whole stream before the response is closed.
            image = Image.open(response.raw).convert("RGB")
    except (requests.RequestException, OSError) as err:
        st.error(f"Could not load the sample image: {err}")
    else:
        st.image(image, caption="Sample Image", use_column_width=True)
        with st.spinner("Processing the image..."):
            segmentation_map = segment_image(image)
        visualize_segmentation(image, segmentation_map.numpy())