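# Streamlit demo: detect objects with DETR, then segment each detection
# (and a couple of fixed point prompts) with SlimSAM.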
import streamlit as st
from transformers import SamModel, SamProcessor, pipeline
from PIL import Image, ImageOps
import numpy as np
import torch
# Constants
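# (x, y) divisors for point prompts: each point is placed at (width / x, height / y),
# i.e. at relative positions (0.5, 0.5) and (0.4, 0.4) of the image.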
XS_YS = [(2.0, 2.0), (2.5, 2.5)]
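# Display width (in pixels) for images rendered in the app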
WIDTH = 600
# Load models
@st.cache_resource
def load_models():
    # SlimSAM: a compressed (pruned) SAM checkpoint used for promptable segmentation
    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
    # DETR pipeline used to detect objects and propose bounding boxes
    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
    return model, processor, od_pipe
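
# Run SlimSAM on an image with either a bounding-box prompt or a single point
# prompt; returns the post-processed masks for the image, or None on failure.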
def process_image(image, model, processor, bounding_box=None, input_point=None):
    try:
        # Convert image to RGB mode
        image = image.convert('RGB')
        # Convert image to a numpy array for the processor
        image_array = np.array(image)
        # Build model inputs from either a bounding-box prompt or a point prompt
        if bounding_box:
            inputs = processor(images=image_array, input_boxes=[bounding_box], return_tensors="pt")
        elif input_point:
            inputs = processor(images=image_array, input_points=[[input_point]], return_tensors="pt")
        else:
            raise ValueError("Either bounding_box or input_point must be provided")
        with torch.no_grad():
            outputs = model(**inputs)
        # Rescale the predicted masks back to the original image size
        predicted_masks = processor.image_processor.post_process_masks(
            outputs.pred_masks,
            inputs["original_sizes"],
            inputs["reshaped_input_sizes"]
        )
        return predicted_masks[0]
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        return None
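
# Overlay each of the three predicted masks on the original image (masked region
# kept, background filled with white) and display the results.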
def display_masked_images(raw_image, predicted_mask, caption_prefix):
    # SlimSAM predicts three candidate masks per prompt; show each one
    for i in range(3):
        mask = predicted_mask[0][i]
        # Convert the boolean mask to an 8-bit grayscale image usable as a composite mask
        int_mask = np.array(mask).astype(int) * 255
        mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
        # Keep the masked region from the original image; fill the rest with white
        final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255, 255, 255)), mask_image)
        st.image(final_image, caption=f"{caption_prefix} {i+1}", width=WIDTH)
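
# Streamlit entry point: upload an image, detect objects with DETR, then segment
# each detected bounding box (and two fixed point prompts) with SlimSAM.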
def main():
st.title("Image Segmentation with Object Detection")
# Introduction and How-to
st.markdown("""
Welcome to the Image Segmentation and Object Detection app, where cutting-edge AI models bring your images to life by identifying and segmenting objects. Here's how it works:
- **Upload an image**: Drag and drop or use the browse files option.
- **Detection**: The `facebook/detr-resnet-50` model detects objects and their bounding boxes.
- **Segmentation**: Following detection, `Zigeng/SlimSAM-uniform-77` segments the objects using the bounding box data.
- **Further Segmentation**: The app also provides additional segmentation insights using input points at positions (0.4, 0.4) and (0.5, 0.5) for a more granular analysis.
Please note that processing takes some time. We appreciate your patience as the models do their work!
""")
# Model credits
st.subheader("Powered by:")
st.write("- Object Detection Model: `facebook/detr-resnet-50`")
st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")

    model, processor, od_pipe = load_models()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        raw_image = Image.open(uploaded_file)
        st.subheader("Uploaded Image")
        st.image(raw_image, caption="Uploaded Image", width=WIDTH)

        with st.spinner('Processing image...'):
            # Object detection: get bounding boxes and labels from DETR
            pipeline_output = od_pipe(raw_image)
            input_boxes_format = [[[b['box']['xmin'], b['box']['ymin']], [b['box']['xmax'], b['box']['ymax']]] for b in pipeline_output]
            labels_format = [b['label'] for b in pipeline_output]

            # Segment each detected object, using its bounding box as the prompt
            for b, l in zip(input_boxes_format, labels_format):
                st.subheader(f"Bounding box: {l}")
                predicted_mask = process_image(raw_image, model, processor, bounding_box=b)
                if predicted_mask is not None:
                    display_masked_images(raw_image, predicted_mask, "Masked Image")

            # Segment again using single point prompts at fixed relative positions
            for (x, y) in XS_YS:
                point_x, point_y = raw_image.size[0] // x, raw_image.size[1] // y
                st.subheader(f"Input point: ({1/x}, {1/y})")
                predicted_mask = process_image(raw_image, model, processor, input_point=[point_x, point_y])
                if predicted_mask is not None:
                    display_masked_images(raw_image, predicted_mask, "Masked Image")
if __name__ == "__main__":
    main()