# Source provenance (Hugging Face Spaces page residue, commented out so the
# file is valid Python):
#   fadzwan's picture
#   Update app.py
#   2411469 verified
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
# Load the model and feature extractor
# SegFormer-B2 fine-tuned for clothes segmentation, fetched from the Hugging
# Face Hub at import time — first startup therefore requires network access.
model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
# NOTE(review): SegformerFeatureExtractor is the legacy preprocessing class;
# recent transformers versions expose SegformerImageProcessor instead —
# confirm the pinned transformers version before migrating.
feature_extractor = SegformerFeatureExtractor.from_pretrained("mattmdjaga/segformer_b2_clothes")
def predict(image):
    """Segment an input image into clothing classes and return a color mask.

    Args:
        image: Input image as supplied by Gradio (PIL image or numpy array);
            forwarded unchanged to the feature extractor.

    Returns:
        PIL.Image.Image: RGB mask at the model's output resolution (the
        SegFormer head emits logits at 1/4 of the preprocessed input size;
        this function does not upsample back to the original image size),
        with each pixel colored by its predicted class.
    """
    # Preprocess (resize/normalize) into model-ready tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only — skip gradient bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_classes, height/4, width/4)
    # Softmax is monotonic, so argmax over raw logits picks the same class;
    # the original softmax step was redundant and has been dropped.
    predicted_class = logits.argmax(dim=1).squeeze().cpu().numpy()  # (H, W)
    # BUG FIX: the original palette had only 4 entries, but this model
    # predicts more classes (18 labels for segformer_b2_clothes), so indexing
    # the palette with argmax output raised IndexError for any class >= 4.
    # Keep the original four colors and deterministically extend the palette
    # to cover every label the model can emit.
    num_labels = model.config.num_labels
    base_colors = np.array(
        [[0, 0, 0],        # class 0 - background (black)
         [255, 0, 0],      # class 1 - red
         [0, 255, 0],      # class 2 - green
         [0, 0, 255]],     # class 3 - blue
        dtype=np.uint8,
    )
    if num_labels > len(base_colors):
        # Fixed seed so each class keeps the same color across runs.
        rng = np.random.default_rng(seed=0)
        extra = rng.integers(
            0, 256, size=(num_labels - len(base_colors), 3), dtype=np.uint8
        )
        color_map = np.vstack([base_colors, extra])
    else:
        color_map = base_colors[:num_labels]
    mask_image = color_map[predicted_class]  # map class indices to RGB, (H, W, 3)
    return Image.fromarray(mask_image)
def segmentation_interface(image):
    """Gradio callback: delegate to ``predict`` and hand back its mask image."""
    mask = predict(image)
    return mask
# Build the Gradio demo — one image in, one segmentation mask out — and
# start serving it.
demo = gr.Interface(
    fn=segmentation_interface,
    inputs="image",
    outputs="image",
)
demo.launch()