import numpy as np
import requests
import streamlit as st
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation

st.title("Mask2Former Semantic Segmentation")
st.write("Upload an image to perform semantic segmentation using Mask2Former.")


# Load Mask2Former fine-tuned on ADE20k semantic segmentation.
# Cached so the checkpoint is loaded once per session rather than on
# every Streamlit rerun.
@st.cache_resource
def load_model():
    processor = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-large-ade-semantic"
    )
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-large-ade-semantic"
    )
    return processor, model


processor, model = load_model()


def segment_image(image: Image.Image) -> torch.Tensor:
    """Run the model and return an (H, W) tensor of per-pixel class ids."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); post-processing expects (height, width),
    # hence the [::-1] reversal.
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    return predicted_semantic_map


def visualize_segmentation(image: Image.Image, segmentation_map: np.ndarray):
    """Show the original image next to the segmentation map overlaid on it."""
    fig = plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.axis("off")
    plt.title("Original Image")

    plt.subplot(1, 2, 2)
    # Draw the image first so the semi-transparent class map reads as an
    # overlay rather than floating on a blank canvas.
    plt.imshow(image)
    plt.imshow(segmentation_map, cmap="jet", alpha=0.7)
    plt.axis("off")
    plt.title("Segmented Image")

    # Pass the figure explicitly (passing the pyplot module is deprecated)
    # and close it afterwards to avoid leaking figures across reruns.
    st.pyplot(fig)
    plt.close(fig)


# File uploader for the user's own image
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_container_width=True)
    if st.button("Segment Image"):
        st.write("Processing the image...")
        segmentation_map = segment_image(image)
        visualize_segmentation(image, segmentation_map.numpy())

# Option to test with a sample image
if st.button("Use Sample Image"):
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
    st.image(image, caption="Sample Image", use_container_width=True)
    st.write("Processing the image...")
    segmentation_map = segment_image(image)
    visualize_segmentation(image, segmentation_map.numpy())
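
# ---------------------------------------------------------------------------
# Running the app -- a minimal sketch, assuming the script above is saved as
# app.py (the filename and the package list are illustrative assumptions, not
# part of the original source):
#
#   pip install streamlit torch transformers pillow matplotlib requests
#   streamlit run app.py
#
# The first run downloads the facebook/mask2former-swin-large-ade-semantic
# checkpoint from the Hugging Face Hub, so expect some startup delay; later
# runs reuse the local cache and the @st.cache_resource-cached model.
# ---------------------------------------------------------------------------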