import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

def load_model(model_type):
    # Model loading simplified for clarity; expects the matching SAM checkpoint
    # (e.g. sam_vit_h_checkpoint.pth) to sit next to this script
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
    model.to(device=device)
    return SamAutomaticMaskGenerator(model)

def segment_and_classify(image, model_type):
    model = load_model(model_type)
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Generate masks
    masks = model.generate(image_cv)
    # Prepare to store segments
    segments = []
    # Loop through masks and extract segments
    for mask_data in masks:
        mask = mask_data['segmentation']
        segment = image_cv * np.tile(mask[:, :, None], [1, 1, 3])  # Apply mask to the image
        segments.append(segment)  # Store the segment for classification
    # Here you would call the classification model (e.g., CLIP) on each segment;
    # a hedged sketch of that step is in classify_segment below.
    # For now, return the first segment (converted back to RGB) for visualization
    if not segments:
        return image  # no masks found; return the input unchanged
    return Image.fromarray(cv2.cvtColor(segments[0], cv2.COLOR_BGR2RGB))
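
# Hedged sketch (not part of the original app flow): one way to fill in the CLIP
# classification step referenced in the comment above. It assumes the Hugging Face
# `transformers` package is available and that the segment is an RGB uint8 array;
# the checkpoint name and candidate labels are illustrative placeholders.
def classify_segment(segment_rgb, candidate_labels):
    from transformers import CLIPModel, CLIPProcessor  # local import keeps the app runnable without transformers
    # Loaded lazily here for simplicity; in practice, cache the model and processor
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    inputs = clip_processor(text=candidate_labels, images=Image.fromarray(segment_rgb),
                            return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = clip_model(**inputs).logits_per_image  # (1, num_labels) image-text similarity scores
    return candidate_labels[logits.softmax(dim=-1).argmax().item()]
# Example with hypothetical labels: classify_segment(segment, ["a dog", "a cat", "background"])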

# Gradio 3+ component API (the older gr.inputs/gr.outputs namespaces were removed)
iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[gr.Image(type="pil"), gr.Dropdown(['vit_h', 'vit_b', 'vit_l'], label="Model Type")],
    outputs=gr.Image(type="pil"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image, select a model type, and receive the segmented and classified parts."
)
iface.launch()
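
# Note (inferred from the imports above, not from the source): running this app needs
# roughly gradio, torch, numpy, opencv-python, Pillow, and segment_anything (installable
# from the facebookresearch/segment-anything repo), plus transformers if the
# classify_segment sketch is used, and the SAM checkpoint files for each model type.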