"""Gradio app that recolours a segmentation mask by coarse category and hides one."""

import os
import tempfile

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

# Label IDs (as stored in the uploaded .npy mask) grouped into the coarse
# categories offered in the UI. Each ID 0-27 appears in exactly one group;
# any ID not listed here is left black in the output.
GROUPED_MAPPING = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
    "Face": [2, 23, 24, 25, 26, 27],  # Face Neck, Lips, Teeth, Tongue
    "Hair": [3],  # Hair
    # Hands, Feet, Arms, Legs, Torso
    "Skin (Hands, Feet, Body)": [
        4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21,
    ],
}

# Display colour for each group (any Matplotlib colour specification).
GROUP_COLORS = {
    "Background": "black",
    "Clothes": "magenta",
    "Face": "orange",
    "Hair": "brown",
    "Skin (Hands, Feet, Body)": "cyan",
}


def process_mask(file, category_to_hide):
    """Recolour a label mask by group, leaving one group black, and save it as PNG.

    Args:
        file: The upload from ``gr.File``. Either a filesystem path
            (``str`` / ``os.PathLike``, as newer Gradio versions pass) or an
            object whose ``.name`` attribute is the path (older tempfile-style
            uploads).
        category_to_hide: A key of ``GROUPED_MAPPING`` whose pixels are left
            black, or ``None`` / any other value to hide nothing.

    Returns:
        Path of a newly created PNG in the system temp directory. Its pixel
        dimensions equal the mask's shape. Each request gets its own file, so
        concurrent calls cannot overwrite each other's result.

    Raises:
        gr.Error: If no file was uploaded, the file cannot be read by
            ``np.load``, or it does not contain a 2-D array.
    """
    if file is None:
        raise gr.Error("Please upload a .npy segmentation file.")
    # A PathLike's ``.name`` is only the basename, so check for path types
    # first and fall back to ``.name`` for file-object style uploads.
    source = file if isinstance(file, (str, os.PathLike)) else file.name

    try:
        # allow_pickle stays at NumPy's default (False): the upload is untrusted.
        data = np.load(source)
    except (OSError, ValueError) as err:
        raise gr.Error(f"Could not read the uploaded file as .npy: {err}") from err
    # np.load returns an NpzFile for .npz archives; also reject non-2-D arrays,
    # which would otherwise produce an image with the wrong number of axes.
    if not isinstance(data, np.ndarray) or data.ndim != 2:
        raise gr.Error("Expected a .npy file containing a 2-D array of class labels.")

    # Start all-black; the hidden group (and any unmapped label) stays black.
    grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)
    for category, labels in GROUPED_MAPPING.items():
        if category == category_to_hide:
            continue
        rgb = [int(c * 255) for c in mcolors.to_rgb(GROUP_COLORS[category])]
        grouped_mask[np.isin(data, labels)] = rgb

    # Unique output file per call; imsave writes the exact pixels without
    # creating a pyplot figure, so no resampling blends the label colours.
    fd, output_path = tempfile.mkstemp(prefix="output_mask_", suffix=".png")
    os.close(fd)
    try:
        plt.imsave(output_path, grouped_mask)
    except Exception:
        os.remove(output_path)
        raise
    return output_path


# Define Gradio Interface
demo = gr.Interface(
    fn=process_mask,
    inputs=[
        gr.File(label="Upload .npy Segmentation File"),
        # Choices come from the mapping so the UI and the logic cannot drift.
        gr.Radio(list(GROUPED_MAPPING), label="Select Category to Hide"),
    ],
    outputs=gr.Image(label="Modified Segmentation Mask"),
    title="Segmentation Mask Editor",
    description="Upload a .npy segmentation file and select a category to mask (hide with black).",
)

if __name__ == "__main__":
    demo.launch()