import io
import logging

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline

logger = logging.getLogger(__name__)

# Load the segmentation pipeline once so every request reuses the same model.
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")

# Example image to save locally (optional; the app works without it).
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"


def _download_example_image(source_url, path, timeout=30):
    """Best-effort download of an example image to *path*.

    Failures are logged rather than raised, because the example image is not
    required for the app to run.
    """
    try:
        response = requests.get(source_url, timeout=timeout)
        response.raise_for_status()
        # Read the full body instead of ``.raw`` so content-encoding is handled.
        with Image.open(io.BytesIO(response.content)) as downloaded:
            # JPEG cannot store alpha/palette modes; normalise to RGB first.
            downloaded.convert("RGB").save(path)
    except (requests.RequestException, OSError) as err:
        logger.warning("Could not download example image from %s: %s", source_url, err)


_download_example_image(url, "example_image.jpg")  # Save the image locally

# Your predefined label dictionary: class id -> human-readable class name.
label_dict = {
    0: "Background",
    1: "Hat",
    2: "Hair",
    3: "Sunglasses",
    4: "Upper-clothes",
    5: "Skirt",
    6: "Pants",
    7: "Dress",
    8: "Belt",
    9: "Left-shoe",
    10: "Right-shoe",
    11: "Face",
    12: "Left-leg",
    13: "Right-leg",
    14: "Left-arm",
    15: "Right-arm",
    16: "Bag",
    17: "Scarf",
}

# Reverse lookup: class name -> class id. Assumes the pipeline's "label"
# strings match the values above (TODO confirm against the model's id2label).
_label_to_id = {name: class_id for class_id, name in label_dict.items()}


def _build_segmentation_map(result, fallback_size):
    """Merge the pipeline's per-label masks into one 2-D array of class ids.

    Args:
        result: List of dicts from the segmentation pipeline, each with a
            "label" string and a PIL "mask".
        fallback_size: ``(width, height)`` used when ``result`` is empty.

    Returns:
        ``uint8`` array of shape ``(height, width)``. Pixels covered by no mask
        stay 0, which is "Background" in ``label_dict``.
    """
    width, height = result[0]["mask"].size if result else fallback_size
    segmentation_map = np.zeros((height, width), dtype=np.uint8)
    for entry in result:
        # Use the real class id, not the position in the result list, so the
        # colours line up with the label_dict-based colorbar.
        class_id = _label_to_id.get(entry["label"])
        if class_id is None:
            logger.warning("Skipping unknown label %r", entry["label"])
            continue
        mask = np.array(entry["mask"])  # Convert the PIL mask to a NumPy array
        segmentation_map[mask > 0] = class_id
    return segmentation_map


def _render_segmentation_map(segmentation_map):
    """Render a class-id map with a labelled colorbar and return it as a PIL image."""
    num_classes = len(label_dict)
    base_colors = plt.cm.tab20.colors
    # One discrete colour per class (cycled if there are more than 20 classes).
    cmap = ListedColormap([base_colors[i % len(base_colors)] for i in range(num_classes)])

    # Use a Figure directly rather than pyplot's global state: pyplot is not
    # safe to share across concurrent Gradio requests.
    fig = Figure(figsize=(8, 8))
    ax = fig.subplots()
    # Fixed limits centred on integer ids so colour i always means class i,
    # whatever subset of classes is actually present in this image.
    im = ax.imshow(
        segmentation_map,
        cmap=cmap,
        vmin=-0.5,
        vmax=num_classes - 0.5,
        interpolation="nearest",
    )
    cbar = fig.colorbar(im, ax=ax, ticks=range(num_classes), label="Classes")
    cbar.ax.set_yticklabels([label_dict[i] for i in range(num_classes)])
    ax.set_title("Combined Segmentation Map")
    ax.axis("off")

    # Render to an in-memory buffer instead of a shared file on disk, so
    # simultaneous requests cannot overwrite each other's output.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # Decode now; Image.open is lazy.
    return rendered


def segment_image(image):
    """Segment clothing/body parts in *image* and return a colour-mapped PIL image.

    Args:
        image: Input PIL image from the Gradio component, or None if nothing
            was uploaded.

    Returns:
        PIL image of the combined segmentation map with a class colorbar.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    result = pipe(image)
    segmentation_map = _build_segmentation_map(result, image.size)
    return _render_segmentation_map(segmentation_map)


# Gradio interface
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),  # Input is an image
    outputs=gr.Image(type="pil"),  # Output is an image with the colormap
    examples=["1.jpg", "2.jpg", "3.jpg"],
    title="Image Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes.",
)

# Launch the app
interface.launch()