Tahir5 commited on
Commit
c4b62dd
·
verified ·
1 Parent(s): 3f30ad1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ # Load Mask2Former fine-tuned on ADE20k semantic segmentation
10
+ st.title("Mask2Former Semantic Segmentation")
11
+ st.write("Upload an image to perform semantic segmentation using Mask2Former.")
12
+
13
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
14
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
15
+
16
+ def segment_image(image: Image.Image):
17
+ inputs = processor(images=image, return_tensors="pt")
18
+ with torch.no_grad():
19
+ outputs = model(**inputs)
20
+ predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
21
+ return predicted_semantic_map
22
+
23
+ def visualize_segmentation(image: Image.Image, segmentation_map):
24
+ plt.figure(figsize=(10, 5))
25
+ plt.subplot(1, 2, 1)
26
+ plt.imshow(image)
27
+ plt.axis("off")
28
+ plt.title("Original Image")
29
+
30
+ plt.subplot(1, 2, 2)
31
+ plt.imshow(segmentation_map, cmap="jet", alpha=0.7)
32
+ plt.axis("off")
33
+ plt.title("Segmented Image")
34
+
35
+ st.pyplot(plt)
36
+
37
+ # File uploader for user to upload an image
38
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
39
+ if uploaded_file:
40
+ image = Image.open(uploaded_file).convert("RGB")
41
+ st.image(image, caption="Uploaded Image", use_column_width=True)
42
+
43
+ if st.button("Segment Image"):
44
+ st.write("Processing the image...")
45
+ segmentation_map = segment_image(image)
46
+ visualize_segmentation(image, segmentation_map.numpy())
47
+
48
+ # Option to test with a sample image
49
+ if st.button("Use Sample Image"):
50
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
51
+ image = Image.open(requests.get(url, stream=True).raw)
52
+ st.image(image, caption="Sample Image", use_column_width=True)
53
+
54
+ st.write("Processing the image...")
55
+ segmentation_map = segment_image(image)
56
+ visualize_segmentation(image, segmentation_map.numpy())