|
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
|
|
|
|
|
# Hugging Face Hub checkpoint used for both the network and its preprocessor.
MODEL_ID = "mattmdjaga/segformer_b2_clothes"

# Loaded once at import time; every request handled by predict() reuses them.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_ID)
feature_extractor = SegformerFeatureExtractor.from_pretrained(MODEL_ID)
|
|
|
def _label_palette(num_labels):
    """Return a ``(num_labels, 3)`` uint8 array mapping class index to RGB.

    Classes 0-3 keep the original colours (black, red, green, blue). Any
    further classes get colours from a fixed-seed RNG, so the mapping is the
    same on every call and every run.
    """
    base = np.array(
        [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=np.uint8
    )
    if num_labels <= len(base):
        return base[:num_labels]
    rng = np.random.default_rng(0)
    extra = rng.integers(0, 256, size=(num_labels - len(base), 3)).astype(np.uint8)
    return np.concatenate([base, extra])


def predict(image):
    """Segment *image* with the SegFormer model and return a colour-coded mask.

    Args:
        image: The input picture. This is a NumPy array when called from the
            Gradio ``"image"`` component; a ``PIL.Image`` also works.
            ``None`` (nothing uploaded) is accepted.

    Returns:
        An RGB ``PIL.Image`` with the same height and width as *image*. Each
        pixel is coloured by its predicted class. Returns ``None`` when
        *image* is ``None``.
    """
    if image is None:
        return None

    # np.asarray handles both NumPy arrays and PIL images; the first two axes
    # are (height, width).
    height, width = np.asarray(image).shape[:2]

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # SegFormer emits logits at a reduced spatial resolution. Resize them to
    # the input size so the mask lines up pixel-for-pixel with the picture.
    logits = torch.nn.functional.interpolate(
        outputs.logits,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )

    # argmax over raw logits gives the same labels as argmax over softmax
    # probabilities. [0] selects the single batch item; squeeze() would also
    # drop a size-1 height or width and corrupt the mask shape.
    predicted_class = logits.argmax(dim=1)[0].cpu().numpy()

    # One colour per class the model can output. A fixed 4-entry table
    # raised IndexError for any label >= 4.
    palette = _label_palette(logits.shape[1])
    return Image.fromarray(palette[predicted_class])
|
|
|
def segmentation_interface(image):
    """Gradio callback: run :func:`predict` on *image* and return its result."""
    mask = predict(image)
    return mask
|
|
|
|
|
# Image in, segmentation mask out; launch() starts the web UI immediately.
demo = gr.Interface(
    fn=segmentation_interface,
    inputs="image",
    outputs="image",
)
demo.launch()
|
|