Walid-Ahmed's picture
Rename app.y to app.py
b1fbd6e verified
raw
history blame
2.65 kB
import io
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline
# Load the segmentation pipeline once at startup; every request reuses it.
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")

# Download the example image and save it locally so Gradio can offer it as an example.
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"
# Stdlib HTTP client: no extra dependency. The ``with`` closes the connection,
# and the timeout stops a stalled server from hanging start-up indefinitely.
# urlopen raises HTTPError on 4xx/5xx responses, so a bad download fails loudly.
with urllib.request.urlopen(url, timeout=30) as response:
    image_bytes = response.read()
# Decode from memory and force RGB: JPEG cannot store alpha or palette modes.
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
image.save("example_image.jpg")  # Save the image locally
# Class id -> display name, used for the colorbar tick labels in segment_image.
# Ids must stay contiguous from 0: segment_image walks this table with
# range(len(label_dict)).
label_dict = dict(
    enumerate(
        (
            "Background",     # 0
            "Hat",            # 1
            "Hair",           # 2
            "Sunglasses",     # 3
            "Upper-clothes",  # 4
            "Skirt",          # 5
            "Pants",          # 6
            "Dress",          # 7
            "Belt",           # 8
            "Left-shoe",      # 9
            "Right-shoe",     # 10
            "Face",           # 11
            "Left-leg",       # 12
            "Right-leg",      # 13
            "Left-arm",       # 14
            "Right-arm",      # 15
            "Bag",            # 16
            "Scarf",          # 17
        )
    )
)
# Function to process the image and generate the segmentation map
def segment_image(image):
    """Segment *image* with ``pipe`` and return a colour-coded class map.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        A PIL image of the rendered map, with one colour per class and a
        colorbar labelled from ``label_dict``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    result = pipe(image)
    segmentation_map = _build_segmentation_map(result, image.size)
    return _render_segmentation_map(segmentation_map)


def _build_segmentation_map(result, size):
    """Merge the pipeline's per-label masks into one (height, width) array of class ids.

    Args:
        result: list of pipeline entries, each read here via ``entry["label"]``
            (a name from ``label_dict``) and ``entry["mask"]`` (a PIL mask).
        size: ``(width, height)`` fallback used only when ``result`` is empty.
    """
    label_to_id = {name: class_id for class_id, name in label_dict.items()}
    width, height = result[0]["mask"].size if result else size
    # Start from class 0 ("Background") everywhere.
    segmentation_map = np.zeros((height, width), dtype=np.uint8)
    for entry in result:
        class_id = label_to_id.get(entry["label"])
        if class_id is None:
            continue  # label not listed in label_dict: leave those pixels as background
        mask = np.array(entry["mask"])  # Convert the PIL mask to a NumPy array
        # Tag pixels with the model's class id, not the entry's position in
        # ``result``, so they line up with the colorbar labels.
        segmentation_map[mask > 0] = class_id
    return segmentation_map


def _render_segmentation_map(segmentation_map):
    """Draw *segmentation_map* with a discrete colormap and return it as a PIL image."""
    num_classes = len(label_dict)
    # One colour per class. With vmin/vmax offset by 0.5, each integer id sits in
    # the centre of its own colour band, aligned with its colorbar tick.
    cmap = plt.get_cmap("tab20", num_classes)
    # A standalone Figure avoids pyplot's global state, which is shared across
    # Gradio worker threads.
    fig = Figure(figsize=(8, 8))
    ax = fig.subplots()
    im = ax.imshow(
        segmentation_map,
        cmap=cmap,
        vmin=-0.5,
        vmax=num_classes - 0.5,
        interpolation="nearest",  # never blend neighbouring ids into fake classes
    )
    cbar = fig.colorbar(im, ax=ax, ticks=range(num_classes), label="Classes")
    cbar.ax.set_yticklabels([label_dict[i] for i in range(num_classes)])
    ax.set_title("Combined Segmentation Map")
    ax.set_axis_off()
    # Render in memory so concurrent requests cannot overwrite a shared output file.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now rather than lazily on first access
    return rendered
# Build the Gradio UI: one PIL image in, one rendered colormap image out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Image Segmentation with Colormap",
    description=(
        "Upload an image, and the segmentation model will produce an output "
        "with a colormap applied to the segmented classes."
    ),
    # Offer the image saved as example_image.jpg at startup as a clickable example.
    examples=["example_image.jpg"],
)

# Start the Gradio web server.
interface.launch()