File size: 2,173 Bytes
c0827b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import tempfile

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
# Label ids of the segmentation map, merged into the coarse groups offered in the UI.
_GROUPED_MAPPING = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
    "Face": [2, 23, 24, 25, 26, 27],  # Face Neck, Lips, Teeth, Tongue
    "Hair": [3],  # Hair
    "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],  # Hands, Feet, Arms, Legs, Torso
}

# Display colour (matplotlib colour name) for each group.
_GROUP_COLORS = {
    "Background": "black",
    "Clothes": "magenta",
    "Face": "orange",
    "Hair": "brown",
    "Skin (Hands, Feet, Body)": "cyan",
}


def _resolve_upload_path(file):
    """Return the filesystem path of an upload given as a path or as an object with ``.name``."""
    if file is None:
        raise ValueError("No file was uploaded; please provide a .npy segmentation file.")
    # Check for real paths first: a pathlib.Path also has ``.name``, but there it
    # is only the final path component, not the full location.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    path = getattr(file, "name", None)
    if not path:
        raise ValueError("Could not determine the path of the uploaded file.")
    return path


def process_mask(file, category_to_hide):
    """Render a grouped, colour-coded segmentation mask with one category hidden.

    Args:
        file: The uploaded ``.npy`` file, either as a filesystem path or as an
            object exposing that path through a ``name`` attribute.
        category_to_hide: Name of the group to leave black (one of the keys of
            the grouping table). ``None`` or an unknown name hides nothing.

    Returns:
        Path of a newly written PNG file containing the rendered mask.

    Raises:
        ValueError: If no file was supplied, the file cannot be read as a
            ``.npy`` array, or the array is not two-dimensional.
    """
    path = _resolve_upload_path(file)
    try:
        data = np.load(path)
    except (OSError, ValueError, EOFError) as err:
        raise ValueError(f"Could not read {path!r} as a .npy array: {err}") from err

    if not isinstance(data, np.ndarray):
        # e.g. an .npz archive: np.load returns an open archive object instead.
        close = getattr(data, "close", None)
        if close is not None:
            close()
        raise ValueError("Expected a single .npy array, not an archive of arrays.")
    if data.ndim != 2:
        raise ValueError(
            f"Expected a 2-D label array, got an array with shape {data.shape}."
        )

    # Pixels of the hidden group (and any unmapped label) keep the zero/black fill.
    grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)
    for category, indices in _GROUPED_MAPPING.items():
        if category == category_to_hide:
            continue
        rgb = mcolors.to_rgb(_GROUP_COLORS[category])
        # round() rather than int(): truncation could turn e.g. 164.999... into 164.
        grouped_mask[np.isin(data, indices)] = [round(c * 255) for c in rgb]

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(grouped_mask)
        ax.axis("off")
        fig.tight_layout()
        # A unique file per call so concurrent requests cannot overwrite each
        # other's output. The file is intentionally left on disk for the caller.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure even if rendering fails, so figures do not
        # accumulate across requests.
        plt.close(fig)
    return output_path
# Build the Gradio UI: a file upload and a category selector feed process_mask,
# whose returned image path is shown in an image component.
_hideable_categories = [
    "Background",
    "Clothes",
    "Face",
    "Hair",
    "Skin (Hands, Feet, Body)",
]
_upload_input = gr.File(label="Upload .npy Segmentation File")
_category_input = gr.Radio(_hideable_categories, label="Select Category to Hide")
_mask_output = gr.Image(label="Modified Segmentation Mask")

demo = gr.Interface(
    fn=process_mask,
    inputs=[_upload_input, _category_input],
    outputs=_mask_output,
    title="Segmentation Mask Editor",
    description="Upload a .npy segmentation file and select a category to mask (hide with black).",
)

# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()
|