"""Gradio demo: clothes/body-part segmentation rendered as a colour-mapped image."""

import io

import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
from transformers import pipeline

# Select a non-interactive backend before pyplot is imported: the handler runs
# inside server worker threads where GUI backends are not usable.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# Load the segmentation pipeline
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")

# Predefined label dictionary: class index -> label name
label_dict = {
    0: "Background",
    1: "Hat",
    2: "Hair",
    3: "Sunglasses",
    4: "Upper-clothes",
    5: "Skirt",
    6: "Pants",
    7: "Dress",
    8: "Belt",
    9: "Left-shoe",
    10: "Right-shoe",
    11: "Face",
    12: "Left-leg",
    13: "Right-leg",
    14: "Left-arm",
    15: "Right-arm",
    16: "Bag",
    17: "Scarf",
}

# Reverse lookup (label name -> class index), built once instead of scanning
# label_dict for every mask returned by the pipeline.
_LABEL_TO_INDEX = {name: idx for idx, name in label_dict.items()}


def segment_image(image):
    """Segment *image* and return a colour-mapped visualisation of the result.

    Args:
        image: The PIL image supplied by the Gradio input component.

    Returns:
        A PIL image containing the rendered segmentation map, with a colorbar
        whose ticks are labelled with the names of the detected classes.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Perform segmentation; each entry carries a "label" and a PIL "mask".
    result = pipe(image)

    # PIL reports size as (width, height); NumPy arrays are (rows, cols).
    # Fall back to the input size when the pipeline returned no masks.
    image_width, image_height = result[0]["mask"].size if result else image.size
    segmentation_map = np.zeros((image_height, image_width), dtype=np.uint8)

    # Combine the per-class masks into a single index map. Later entries
    # overwrite earlier ones wherever masks overlap.
    for entry in result:
        label = entry["label"]
        class_idx = _LABEL_TO_INDEX.get(label)
        if class_idx is None:
            # A label outside label_dict cannot be shown in the legend.
            print(f"Skipping unknown label: {label}")
            continue
        mask = np.array(entry["mask"])
        segmentation_map[mask > 0] = class_idx

    # Class indices actually present in the map (sorted by np.unique).
    unique_classes = [int(idx) for idx in np.unique(segmentation_map)]

    print("Detected Classes:")
    for class_idx in unique_classes:
        print(f"- {label_dict[class_idx]}")

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Pin the colour range to the full label set so a given class gets the
        # same colour in every output; nearest-neighbour interpolation keeps
        # class indices from being blended at region borders.
        rendered = ax.imshow(
            segmentation_map,
            cmap="tab20",
            vmin=0,
            vmax=max(label_dict),
            interpolation="nearest",
        )

        # Colorbar ticks only for the classes that were detected.
        cbar = fig.colorbar(rendered, ax=ax, ticks=unique_classes)
        cbar.ax.set_yticklabels([label_dict[idx] for idx in unique_classes])

        ax.set_title("Segmented Image with Detected Classes")
        ax.axis("off")

        # Render in memory: a fixed file name on disk would be shared (and
        # overwritten) by concurrent requests.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, even if rendering failed.
        plt.close(fig)

    buffer.seek(0)
    output = Image.open(buffer)
    output.load()  # Decode now so the image no longer depends on the buffer.
    return output


# Gradio interface
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),  # Input is an image
    outputs=gr.Image(type="pil"),  # Output is an image with the colormap
    examples=["1.jpg", "2.jpg", "3.jpg"],
    title="Clothes Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes.",
)

# Launch the app
interface.launch()