import gradio as gr
import torch
import numpy as np
from PIL import Image
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# Cache generators so checkpoints are not reloaded on every request
_model_cache = {}

def load_model(model_type):
    if model_type not in _model_cache:
        # The checkpoint filename is a placeholder; point it at your local SAM weights
        sam = sam_model_registry[model_type](checkpoint=f"sam_{model_type}_checkpoint.pth")
        sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
        _model_cache[model_type] = SamAutomaticMaskGenerator(sam)
    return _model_cache[model_type]
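
# One way to satisfy the placeholder checkpoint path above: fetch the official
# SAM weights on first run. The URLs below are the public checkpoints from the
# segment-anything release; call ensure_checkpoint(model_type) before
# load_model if the file is not already on disk. This is a sketch, not part of
# the original app.
import os
import urllib.request

_SAM_CHECKPOINT_URLS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}

def ensure_checkpoint(model_type):
    # Download the checkpoint to the filename load_model expects, if missing
    path = f"sam_{model_type}_checkpoint.pth"
    if not os.path.exists(path):
        urllib.request.urlretrieve(_SAM_CHECKPOINT_URLS[model_type], path)
    return path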

def segment_and_classify(image, model_type):
    model = load_model(model_type)
    # SAM expects an RGB uint8 array; PIL images from Gradio are already RGB
    image_np = np.array(image.convert("RGB"))

    # Generate a mask for every region SAM finds in the image
    masks = model.generate(image_np)
    if not masks:
        return image  # nothing was segmented; return the input unchanged

    # Extract each masked region as its own image
    segments = []
    for mask_data in masks:
        mask = mask_data['segmentation']  # boolean HxW array
        segment = image_np * mask[:, :, None]  # zero out pixels outside the mask
        segments.append(segment)

    # Here you would call the classification model (e.g., CLIP) on each segment;
    # for now, return the first segment for visualization (see the sketch below)
    return Image.fromarray(segments[0])
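
# A minimal sketch of the classification step mentioned above, using CLIP via
# Hugging Face's transformers library (an assumption; the original code does
# not specify the classifier). The candidate labels are placeholders; swap in
# whatever classes your application needs.
from transformers import CLIPProcessor, CLIPModel

_clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
_clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def classify_segment(segment_np, candidate_labels=("cat", "dog", "car", "person")):
    # Score the segment against each text label and return the best match
    inputs = _clip_processor(
        text=list(candidate_labels),
        images=Image.fromarray(segment_np),
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = _clip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)  # shape: (1, num_labels)
    return candidate_labels[probs.argmax().item()]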

iface = gr.Interface(
    fn=segment_and_classify,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(["vit_h", "vit_b", "vit_l"], value="vit_b", label="Model Type"),
    ],
    outputs=gr.Image(type="pil", label="First Segment"),
    title="SAM Model Segmentation and Classification",
    description="Upload an image and select a SAM model type; the app returns the first segmented region (classification is left as a stub).",
)

iface.launch()