import streamlit as st
from transformers import SamModel, SamProcessor
from transformers import pipeline
from PIL import Image
import numpy as np
import torch
def main():
    st.title("Image Segmentation with Object Detection")

    # Introduction and how-to
    st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:

- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further segmentation**: The app also segments around single input points at relative positions (0.4, 0.4) and (0.5, 0.5) of the image size, for a more granular analysis.

Please note that processing takes some time. We appreciate your patience while the models do their work!
""")

    # Model credits
    st.subheader("Powered by:")
    st.write("- Object Detection Model: `facebook/detr-resnet-50`")
    st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")
    # Load SlimSAM, a pruned variant of Meta's Segment Anything Model
    # (facebook/sam-vit-huge is a larger, slower drop-in alternative)
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")

    # Load the object-detection pipeline
    od_pipe = pipeline("object-detection", model="facebook/detr-resnet-50")
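    # Note: Streamlit re-runs this whole script on every interaction, so the
    # models above are reloaded each time. A minimal sketch of the usual fix,
    # caching the loads across reruns with st.cache_resource (load_models is a
    # hypothetical helper, not part of the original app):
    #
    #   @st.cache_resource
    #   def load_models():
    #       sam = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    #       proc = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    #       det = pipeline("object-detection", model="facebook/detr-resnet-50")
    #       return sam, proc, det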
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Divisors of the image width/height used to place the single-point prompts:
    # (2.0, 2.0) is the image centre, (2.5, 2.5) is 40% in from the top-left.
    xs_ys = [(2.0, 2.0), (2.5, 2.5)]
    width = 600  # display width for every image shown in the app
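    # For example, on a 600x400 image the divisor pair (2.5, 2.5) places the
    # prompt at (int(600 / 2.5), int(400 / 2.5)) = (240, 160), i.e. the relative
    # position (0.4, 0.4) mentioned in the introduction.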
    if uploaded_file is not None:
        # Convert to RGB so images with an alpha or palette channel composite cleanly below
        raw_image = Image.open(uploaded_file).convert("RGB")
        st.subheader("Uploaded Image")
        st.image(raw_image, caption="Uploaded Image", width=width)

        ### STEP 1. Object detection
        pipeline_output = od_pipe(raw_image)

        # Convert each detection's bounding box into the flat
        # [xmin, ymin, xmax, ymax] format the SAM processor expects
        input_boxes_format = [
            [b['box']['xmin'], b['box']['ymin'], b['box']['xmax'], b['box']['ymax']]
            for b in pipeline_output
        ]
        labels_format = [b['label'] for b in pipeline_output]
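        # Each element of pipeline_output is a dict of the form
        #   {'score': 0.998, 'label': 'cat',
        #    'box': {'xmin': 52, 'ymin': 81, 'xmax': 330, 'ymax': 412}}
        # (values here are illustrative; they depend on the uploaded image)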
        ### STEP 2. Segmentation prompted by the detected boxes
        for box, label in zip(input_boxes_format, labels_format):
            with st.spinner('Processing...'):
                st.subheader(f"Bounding box: {label}")
                # input_boxes is nested (batch, num_boxes, 4): one image, one box
                inputs = processor(images=raw_image,
                                   input_boxes=[[box]],
                                   return_tensors="pt")
                with torch.no_grad():
                    outputs = model(**inputs)
                predicted_masks = processor.image_processor.post_process_masks(
                    outputs.pred_masks,
                    inputs["original_sizes"],
                    inputs["reshaped_input_sizes"]
                )
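                # post_process_masks returns one tensor per input image; the
                # single tensor here has shape (num_prompts, 3, H, W), because
                # SAM proposes three candidate masks per prompt, resized back
                # to the original image resolution.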
                predicted_mask = predicted_masks[0]
                for i in range(3):
                    # Boolean (H, W) mask -> 8-bit grayscale image (0 or 255)
                    mask = predicted_mask[0][i]
                    int_mask = np.array(mask).astype(np.uint8) * 255
                    mask_image = Image.fromarray(int_mask, mode='L')
                    # Keep the original pixels where the mask is white, paint the rest white
                    final_image = Image.composite(
                        raw_image,
                        Image.new('RGB', raw_image.size, (255, 255, 255)),
                        mask_image
                    )
                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)
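                    # Design note: compositing against white discards everything
                    # outside the mask; blending a semi-transparent colour over
                    # the original (e.g. with PIL's Image.blend) is a common
                    # alternative when the scene context should stay visible.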
        ### STEP 3. Segmentation prompted by single points
        for (x, y) in xs_ys:
            with st.spinner('Processing...'):
                # Place one prompt point inside the image
                point_x = int(raw_image.size[0] / x)
                point_y = int(raw_image.size[1] / y)
                # input_points is nested per the SAM processor: one image, one point
                input_points = [[[point_x, point_y]]]

                # Prepare inputs
                inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")

                # Generate masks
                with torch.no_grad():
                    outputs = model(**inputs)

                # Post-process masks
                predicted_masks = processor.image_processor.post_process_masks(
                    outputs.pred_masks,
                    inputs["original_sizes"],
                    inputs["reshaped_input_sizes"]
                )
                predicted_mask = predicted_masks[0]

                # Display the three candidate masks for this point
                st.subheader(f"Input point: ({1/x:.1f}, {1/y:.1f}) of image size")
                for i in range(3):
                    mask = predicted_mask[0][i]
                    int_mask = np.array(mask).astype(np.uint8) * 255
                    mask_image = Image.fromarray(int_mask, mode='L')
                    final_image = Image.composite(
                        raw_image,
                        Image.new('RGB', raw_image.size, (255, 255, 255)),
                        mask_image
                    )
                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)
if __name__ == "__main__":
    main()
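# Run locally (assuming this file is saved as app.py):
#   streamlit run app.py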