|
import io
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from transformers import pipeline
|
|
|
|
|
# Hugging Face image-segmentation pipeline backed by the SegFormer-B2
# clothes model; called once per request by segment_image().
pipe = pipeline(task="image-segmentation", model="mattmdjaga/segformer_b2_clothes")
|
|
|
|
|
# Download a sample photo at start-up and keep a local JPEG copy on disk.
# The stdlib urllib.request is used, so no third-party HTTP client is needed.
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"

_request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
# urlopen raises HTTPError on 4xx/5xx, so a failed download is not silently
# handed to PIL; the timeout keeps start-up from hanging forever.
with urllib.request.urlopen(_request, timeout=30) as _response:
    # Buffer the whole body so PIL gets a seekable file object.
    image = Image.open(io.BytesIO(_response.read()))

# JPEG cannot store alpha or palette modes; normalise to RGB before saving.
image = image.convert("RGB")
image.save("example_image.jpg")
|
|
|
|
|
# Numeric class id -> human-readable class name used for the colour bar.
# NOTE(review): presumably mirrors the model's id2label mapping — verify
# against the mattmdjaga/segformer_b2_clothes config.
label_dict = dict(
    enumerate(
        [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ]
    )
)
|
|
|
|
|
def segment_image(image):
    """Segment *image* with ``pipe`` and render a colour-coded class map.

    Every pixel is painted with the numeric id (from ``label_dict``) of the
    pipeline mask that covers it. The map is then drawn with one fixed
    colour per class and a colour bar naming each class.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            the user submits without uploading an image.

    Returns:
        A PIL image of the rendered figure, or ``None`` if *image* is ``None``.
    """
    if image is None:
        return None

    result = pipe(image)

    # Map the pipeline's label strings back to label_dict ids. Using the
    # position in ``result`` would only number the classes detected in this
    # particular image, so colours/ticks would name the wrong class.
    label_to_id = {name: class_id for class_id, name in label_dict.items()}

    if result:
        image_width, image_height = result[0]["mask"].size
    else:
        # No segments returned: fall back to the input size (all background).
        image_width, image_height = image.size
    # Pixels not covered by any mask stay 0, i.e. "Background".
    segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)

    for entry in result:
        class_id = label_to_id.get(entry["label"])
        if class_id is None:
            continue  # label not present in label_dict; leave pixels as-is
        mask = np.array(entry["mask"])
        segmentation_map[mask > 0] = class_id

    num_classes = len(label_dict)
    # One discrete colour per class, with the colour range pinned to
    # [-0.5, num_classes - 0.5] so colour i always means class i regardless
    # of which classes appear in this image.
    cmap = plt.get_cmap("tab20", num_classes)

    # Use an explicit figure rather than global pyplot state, and render to
    # memory instead of a shared file, so concurrent requests don't collide.
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        im = ax.imshow(
            segmentation_map,
            cmap=cmap,
            vmin=-0.5,
            vmax=num_classes - 0.5,
            interpolation="nearest",  # never blend neighbouring class ids
        )
        cbar = fig.colorbar(im, ax=ax, ticks=list(range(num_classes)), label="Classes")
        cbar.ax.set_yticklabels([label_dict[i] for i in range(num_classes)])
        ax.set_title("Combined Segmentation Map")
        ax.axis("off")

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    buffer.seek(0)
    output = Image.open(buffer)
    output.load()  # decode now; Image.open is lazy
    return output
|
|
|
|
|
# Image-in / image-out web UI around segment_image, with bundled examples.
interface = gr.Interface(
    segment_image,
    gr.Image(type="pil"),
    gr.Image(type="pil"),
    examples=["1.jpg", "2.jpg", "3.jpg"],
    title="Image Segmentation with Colormap",
    description=(
        "Upload an image, and the segmentation model will produce an output "
        "with a colormap applied to the segmented classes."
    ),
)

# Start the Gradio web app.
interface.launch()
|
|