# app.py — uploaded by devendergarg14 ("Create app.py", commit c0827b5, verified, 2.17 kB).
import os
import tempfile

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
# Segmentation label ids that are merged into each display group.
_GROUPED_MAPPING = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
    "Face": [2, 23, 24, 25, 26, 27],  # Face, Neck, Lips, Teeth, Tongue
    "Hair": [3],  # Hair
    # Hands, Feet, Arms, Legs, Torso
    "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],
}

# Matplotlib color name used to paint each group.
_GROUP_COLORS = {
    "Background": "black",
    "Clothes": "magenta",
    "Face": "orange",
    "Hair": "brown",
    "Skin (Hands, Feet, Body)": "cyan",
}


def _resolve_upload_path(file):
    """Return a filesystem path for the value delivered by the ``gr.File`` input.

    The upload may arrive as a path (``str`` / ``os.PathLike``) or as a
    tempfile-like object exposing ``.name``; both are accepted.
    """
    if file is None:
        raise gr.Error("Please upload a .npy segmentation file.")
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _build_grouped_mask(data, category_to_hide):
    """Paint each label group of ``data`` in its color as a ``(*data.shape, 3)`` uint8 image.

    Pixels belonging to ``category_to_hide``, and labels not listed in
    ``_GROUPED_MAPPING``, stay black (zero).
    """
    grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)
    for category, indices in _GROUPED_MAPPING.items():
        if category == category_to_hide:
            continue  # leave the hidden group's pixels black
        # Round (rather than truncate) when scaling the 0-1 float color to 0-255.
        rgb = np.round(np.asarray(mcolors.to_rgb(_GROUP_COLORS[category])) * 255)
        grouped_mask[np.isin(data, indices)] = rgb.astype(np.uint8)
    return grouped_mask


def _render_mask_png(grouped_mask):
    """Draw ``grouped_mask`` with matplotlib into a new temporary PNG and return its path."""
    # A unique file per call, so simultaneous requests cannot overwrite each
    # other's output the way a fixed name in the working directory can.
    fd, output_path = tempfile.mkstemp(prefix="output_mask_", suffix=".png")
    os.close(fd)  # only the path is needed; savefig reopens the file itself
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        # Nearest-neighbour display keeps every pixel one of the group colors;
        # matplotlib's default resampling can blend colors along label edges.
        ax.imshow(grouped_mask, interpolation="nearest")
        ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close exactly this figure, even if saving failed, so figures don't accumulate.
        plt.close(fig)
    return output_path


def process_mask(file, category_to_hide):
    """Render a grouped, color-coded view of an uploaded segmentation mask.

    Args:
        file: The ``gr.File`` upload holding a ``.npy`` label array. This is either
            a path or an object with a ``.name`` attribute. The array must be 2-D
            for ``imshow`` to render the resulting RGB image.
        category_to_hide: Name of the group (a key of ``_GROUPED_MAPPING``) whose
            pixels are left black. Any other value, e.g. ``None`` when nothing is
            selected, hides nothing.

    Returns:
        Path to a PNG image of the colored mask.

    Raises:
        gr.Error: If no file was uploaded.
    """
    # np.load keeps its default allow_pickle=False, so an uploaded file cannot
    # trigger unpickling of arbitrary objects.
    data = np.load(_resolve_upload_path(file))
    return _render_mask_png(_build_grouped_mask(data, category_to_hide))
# Gradio UI: an .npy upload plus a group selector in, the rendered mask image out.
# Components are created in the same order as before: file, radio, then image.
_mask_upload = gr.File(label="Upload .npy Segmentation File")
_hide_choice = gr.Radio(
    ["Background", "Clothes", "Face", "Hair", "Skin (Hands, Feet, Body)"],
    label="Select Category to Hide",
)
demo = gr.Interface(
    fn=process_mask,
    inputs=[_mask_upload, _hide_choice],
    outputs=gr.Image(label="Modified Segmentation Mask"),
    title="Segmentation Mask Editor",
    description="Upload a .npy segmentation file and select a category to mask (hide with black).",
)
def main():
    """Start the Gradio server hosting the segmentation mask editor."""
    demo.launch()


if __name__ == "__main__":
    main()