Spaces:
Runtime error
Runtime error
import io
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from transformers import pipeline
# Hugging Face Hub checkpoint used for image segmentation.
MODEL_ID = "mattmdjaga/segformer_b2_clothes"

# Build the segmentation pipeline once at import time; segment_image() reuses
# this single module-level instance for every call.
pipe = pipeline("image-segmentation", model=MODEL_ID)
# Download an example photo once at start-up and save it locally as
# example_image.jpg so it can be offered as a Gradio example.
url = "https://plus.unsplash.com/premium_photo-1673210886161-bfcc40f54d1f?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8cGVyc29uJTIwc3RhbmRpbmd8ZW58MHx8MHx8&w=1000&q=80"

# Use the stdlib HTTP client (the original called `requests` without importing
# it, which crashed the app at start-up). Some CDNs reject the default
# "Python-urllib" User-Agent, so send a generic browser one. The timeout stops
# a stalled download from hanging start-up. urlopen raises HTTPError on 4xx/5xx
# responses, so an error page is never handed to PIL.
_request = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
with urllib.request.urlopen(_request, timeout=30) as _response:
    image = Image.open(io.BytesIO(_response.read()))

# JPEG cannot store alpha or palette modes; normalise to RGB before saving.
image = image.convert("RGB")
image.save("example_image.jpg")  # Save the image locally
# Class id -> human-readable class name. Ids are the list positions (0-17).
# segment_image() also uses these names as the colour-bar tick labels.
# NOTE(review): presumably mirrors the model's id2label config — verify if the
# checkpoint changes.
label_dict = dict(
    enumerate(
        [
            "Background",
            "Hat",
            "Hair",
            "Sunglasses",
            "Upper-clothes",
            "Skirt",
            "Pants",
            "Dress",
            "Belt",
            "Left-shoe",
            "Right-shoe",
            "Face",
            "Left-leg",
            "Right-leg",
            "Left-arm",
            "Right-arm",
            "Bag",
            "Scarf",
        ]
    )
)
# Function to process the image and generate the segmentation map
def segment_image(image):
    """Segment *image* and return a colour-coded class map as a PIL image.

    Each segment returned by the pipeline is painted with the id that
    ``label_dict`` assigns to its label name. The colour bar therefore names
    the class actually shown in each colour band.

    Args:
        image: Input ``PIL.Image`` from the Gradio image component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A ``PIL.Image`` (decoded from an in-memory PNG) showing the map with a
        labelled colour bar, or ``None`` if no image was supplied.
    """
    if image is None:
        return None

    result = pipe(image)

    # Start from an all-zero map. Pixels not covered by any mask stay 0,
    # which label_dict names "Background". With no segments at all, size the
    # map from the input image instead of indexing result[0].
    if result:
        width, height = result[0]["mask"].size  # PIL size is (width, height)
    else:
        width, height = image.size
    segmentation_map = np.zeros((height, width), dtype=np.uint8)

    # The pipeline returns only the classes present in the image, so a
    # segment's list position is NOT its class id. Map each segment's label
    # name back to its id in label_dict instead.
    label_to_id = {name: class_id for class_id, name in label_dict.items()}
    for entry in result:
        class_id = label_to_id.get(entry["label"])
        if class_id is None:
            # Label not in label_dict: it has no colour-bar entry, so skip it
            # rather than paint it under some other class's name.
            continue
        mask = np.asarray(entry["mask"])  # PIL mask -> NumPy array
        segmentation_map[mask > 0] = class_id  # later segments overwrite earlier ones

    num_classes = len(label_dict)
    # One discrete colour per class id. The fixed half-unit limits centre every
    # integer id in its own band, whichever ids happen to be present, so
    # colours and tick labels always line up.
    cmap = plt.get_cmap("tab20", num_classes)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        im = ax.imshow(
            segmentation_map, cmap=cmap, vmin=-0.5, vmax=num_classes - 0.5
        )
        cbar = fig.colorbar(im, ax=ax, ticks=range(num_classes), label="Classes")
        cbar.ax.set_yticklabels([label_dict[i] for i in range(num_classes)])
        ax.set_title("Combined Segmentation Map")
        ax.axis("off")
        # Render into memory instead of a fixed file on disk, so concurrent
        # requests cannot overwrite each other's output.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)  # free the figure even if rendering fails

    buffer.seek(0)
    output = Image.open(buffer)
    output.load()  # decode eagerly so the returned image is fully materialised
    return output
# Gradio interface: one PIL image in, the rendered segmentation map out.
interface = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),  # hand the upload to segment_image as a PIL image
    outputs=gr.Image(type="pil"),  # segment_image returns a PIL image (or None)
    # Example paths are relative to the working directory.
    # NOTE(review): confirm 1.jpg-3.jpg ship with the app.
    examples=["1.jpg", "2.jpg", "3.jpg"],
    title="Image Segmentation with Colormap",
    description="Upload an image, and the segmentation model will produce an output with a colormap applied to the segmented classes."
)

# Launch the app only when run as a script (e.g. `python app.py`), so that
# importing this module does not also start a web server.
if __name__ == "__main__":
    interface.launch()