File size: 1,649 Bytes
8c06730
97cb2c8
2411469
7cf1bb2
97cb2c8
7cf1bb2
8c06730
 
 
 
2411469
8c06730
2411469
8c06730
2411469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c06730
 
 
 
 
2411469
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import colorsys

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Hugging Face Hub checkpoint shared by the model and its preprocessor.
_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"

# Both objects are loaded once, at import time, and reused by every call to
# predict() below. The model is loaded before the preprocessor.
# NOTE(review): newer transformers releases deprecate SegformerFeatureExtractor
# in favour of SegformerImageProcessor — consider switching once the pinned
# transformers version allows it.
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)

def _build_color_map(num_classes):
    """Return a ``(num_classes, 3)`` uint8 RGB palette with one row per class.

    Classes 0-3 keep the fixed colours black / red / green / blue. Any further
    classes get evenly spaced hues, so every class index the model can emit
    has a colour and indexing the palette never goes out of bounds.
    """
    base = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    extra_count = max(num_classes - len(base), 0)
    extra = [
        tuple(
            round(channel * 255)
            for channel in colorsys.hsv_to_rgb((i + 0.5) / extra_count, 0.6, 0.9)
        )
        for i in range(extra_count)
    ]
    # reshape keeps a (0, 3) shape even when num_classes == 0.
    return np.array((base + extra)[:num_classes], dtype=np.uint8).reshape(-1, 3)


def predict(image):
    """Segment *image* with the SegFormer model and return a colour-coded mask.

    Args:
        image: The input picture. This is either a ``PIL.Image.Image`` or an
            array-like of shape ``(height, width[, channels])``, e.g. what a
            Gradio ``"image"`` input delivers. ``None`` is passed through.

    Returns:
        A ``PIL.Image.Image`` with the same height and width as *image*. Each
        pixel is coloured by the class predicted for it. Returns ``None`` when
        *image* is ``None``.
    """
    if image is None:
        # Presumably Gradio passes None when no image was submitted; there is
        # nothing to segment, so leave the output empty instead of crashing.
        return None

    # Remember the input size so the mask can be produced at that resolution.
    if isinstance(image, Image.Image):
        height, width = image.height, image.width
    else:
        height, width = np.asarray(image).shape[:2]

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # SegFormer's logits are documented at a reduced (1/4) resolution, so
    # resize them to the input size before picking a class per pixel.
    logits = torch.nn.functional.interpolate(
        outputs.logits, size=(height, width), mode="bilinear", align_corners=False
    )
    # logits: (batch_size, num_classes, height, width). Softmax would not
    # change which class scores highest, so argmax the raw logits directly.
    # Take batch item 0 explicitly rather than squeeze(), which would also drop
    # a spatial dimension of size 1.
    predicted_class = logits.argmax(dim=1)[0].cpu().numpy()  # (height, width)

    # One colour per class the model can emit, so indexing cannot overflow.
    color_map = _build_color_map(logits.shape[1])
    mask_image = color_map[predicted_class]  # (height, width, 3) uint8
    return Image.fromarray(mask_image)

def segmentation_interface(image):
    """Gradio callback that returns the segmentation mask for *image*.

    A thin adapter that keeps the UI wiring separate from the model code. All
    the work is delegated to ``predict`` and its result is returned unchanged.
    """
    mask = predict(image)
    return mask

# Gradio UI: one image in, the colour-coded segmentation mask out.
demo = gr.Interface(fn=segmentation_interface, inputs="image", outputs="image")

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    demo.launch()