from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import gradio as gr
import numpy as np
from PIL import Image
# Hub checkpoint that supplies both the network weights and the preprocessing config
_CHECKPOINT = "mattmdjaga/segformer_b2_clothes"
# Build the preprocessor and the segmentation network once, at import time
# NOTE(review): SegformerFeatureExtractor appears to be deprecated in newer transformers releases - confirm
feature_extractor = SegformerFeatureExtractor.from_pretrained(_CHECKPOINT)
model = SegformerForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def _class_palette(num_classes):
    """Return a ``(rows, 3)`` uint8 RGB table with one colour per class index.

    The first four rows keep the original hand-picked colours (black, red,
    green, blue). Rows beyond those are generated deterministically with the
    PASCAL VOC bit-interleaving scheme, so a model with more than four classes
    no longer indexes past the end of the table.
    """
    palette = [(0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255)]
    for index in range(len(palette), num_classes):
        red = green = blue = 0
        remaining = index
        # Spread the bits of the class index over the three channels, most
        # significant colour bit first.
        for shift in range(7, -1, -1):
            red |= (remaining & 1) << shift
            green |= ((remaining >> 1) & 1) << shift
            blue |= ((remaining >> 2) & 1) << shift
            remaining >>= 3
        palette.append((red, green, blue))
    return np.array(palette, dtype=np.uint8)


def predict(image):
    """Segment ``image`` and return a colour-coded class mask.

    Args:
        image: The picture to segment: a PIL image, or an array-like whose
            first two dimensions are (height, width). ``None`` is passed
            through unchanged (presumably what the UI sends when nothing was
            uploaded - confirm against the Gradio component in use).

    Returns:
        A PIL image with the same height and width as ``image`` in which every
        pixel carries the colour of its predicted class, or ``None`` when
        ``image`` is ``None``.
    """
    if image is None:
        return None

    # Imported here because this file has no module-level torch import.
    import torch

    # Remember the input resolution so the mask can be returned at that size.
    if isinstance(image, Image.Image):
        target_size = image.size[::-1]  # PIL reports (width, height)
    else:
        target_size = tuple(np.asarray(image).shape[:2])

    # Prepare the image for the model
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch_size, num_classes, height, width)
        # Resize the logits to the input resolution before taking the argmax,
        # so the mask lines up pixel-for-pixel with the input.
        logits = torch.nn.functional.interpolate(
            logits, size=target_size, mode="bilinear", align_corners=False
        )

    # argmax over the class axis; softmax is monotonic, so it is not needed.
    # Index the batch axis explicitly rather than squeeze(), which would also
    # drop a genuine size-1 height or width.
    predicted_class = logits.argmax(dim=1)[0].cpu().numpy()  # (height, width)

    # Size the colour table from the logits so every class id has a row.
    color_map = _class_palette(logits.shape[1])
    mask_image = color_map[predicted_class]  # Map class indices to colors
    return Image.fromarray(mask_image)  # Convert to PIL Image
def segmentation_interface(image):
    """Gradio callback: hand ``image`` to ``predict`` and return what it produces."""
    segmentation_mask = predict(image)
    return segmentation_mask
# Wire the segmentation callback into an image-in / image-out Gradio UI, then start serving it
demo = gr.Interface(
    fn=segmentation_interface,
    inputs="image",
    outputs="image",
)
demo.launch()
 |