|
import tempfile

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
|
|
|
def process_mask(file, category_to_hide):
    """Render a colour-coded PNG of a label map with one category hidden.

    The ``.npy`` array is read as a 2-D map of integer label indices. Each
    index belongs to one of five category groups, and every group is painted
    in a fixed colour. Pixels of the group named by ``category_to_hide`` are
    left black, as are label values that belong to no group. Note that
    "Background" is itself painted black, so hiding it does not change the
    picture.

    Args:
        file: The uploaded file, either as an object exposing its path via
            ``.name`` or as a plain path string.
        category_to_hide: Name of the category group to leave black. ``None``
            or a name that matches no group hides nothing.

    Returns:
        Path of the PNG file the rendered mask was written to. Each call
        writes a new temporary file, which is not deleted here.

    Raises:
        ValueError: If no file was supplied or the file does not contain a
            2-D array.
    """
    if file is None:
        raise ValueError("No segmentation file was uploaded.")

    # Accept both an object carrying the path in ``.name`` and a bare path.
    npy_path = getattr(file, "name", file)
    data = np.load(npy_path)
    if not isinstance(data, np.ndarray) or data.ndim != 2:
        raise ValueError(
            "Expected a .npy file containing a 2-D array of label indices."
        )

    grouped_mapping = {
        "Background": [0],
        "Clothes": [1, 12, 22, 8, 9, 17, 18],
        "Face": [2, 23, 24, 25, 26, 27],
        "Hair": [3],
        "Skin (Hands, Feet, Body)": [
            4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21,
        ],
    }

    group_colors = {
        "Background": "black",
        "Clothes": "magenta",
        "Face": "orange",
        "Hair": "brown",
        "Skin (Hands, Feet, Body)": "cyan",
    }

    # Start all-black so hidden and unmapped labels need no extra handling.
    grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)
    for category, indices in grouped_mapping.items():
        if category == category_to_hide:
            continue
        # Convert the colour once per category rather than once per label.
        rgb = mcolors.to_rgb(group_colors[category])
        color = [round(channel * 255) for channel in rgb]
        grouped_mask[np.isin(data, indices)] = color

    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(grouped_mask)
        ax.axis("off")
        fig.tight_layout()

        # A unique file per call keeps concurrent requests from overwriting
        # each other's result.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Close this specific figure even if rendering or saving failed.
        plt.close(fig)

    return output_path
|
|
|
|
|
# Gradio UI: a file upload and a category selector are passed to
# process_mask, and the image path it returns is displayed.
demo = gr.Interface(
    fn=process_mask,
    inputs=[
        # Limit the upload picker to .npy files, matching the label and the
        # description shown to the user.
        gr.File(label="Upload .npy Segmentation File", file_types=[".npy"]),
        # The choices are compared verbatim with the category names inside
        # process_mask, so the strings must stay identical to those names.
        gr.Radio(
            [
                "Background",
                "Clothes",
                "Face",
                "Hair",
                "Skin (Hands, Feet, Body)",
            ],
            label="Select Category to Hide",
        ),
    ],
    outputs=gr.Image(label="Modified Segmentation Mask"),
    title="Segmentation Mask Editor",
    description="Upload a .npy segmentation file and select a category to mask (hide with black).",
)


if __name__ == "__main__":
    # Launch the app only when this file is executed as a script.
    demo.launch()
|
|