File size: 2,173 Bytes
c0827b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import tempfile

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

def process_mask(file, category_to_hide):
    """Render a grouped, colour-coded segmentation mask with one group hidden.

    The label array stored in the uploaded ``.npy`` file is collapsed into
    five groups (Background, Clothes, Face, Hair, Skin). Every group except
    ``category_to_hide`` is painted with its group colour; the hidden group,
    and any label id not listed in a group, stays black.

    Args:
        file: The uploaded file. Either an object exposing the on-disk path
            as ``.name`` or a plain path string.
        category_to_hide: Name of the group to leave black. A value that
            matches no group (including ``None``) hides nothing.

    Returns:
        Path of a newly created PNG file containing the rendered mask.

    Raises:
        ValueError: If no file was supplied or the loaded data is not a
            2-D array.
    """
    if file is None:
        raise ValueError("No segmentation file was uploaded.")

    # Accept both an object carrying the path in ``.name`` and a bare path.
    npy_path = getattr(file, "name", file)
    data = np.load(npy_path)
    if not isinstance(data, np.ndarray) or data.ndim != 2:
        # Fail early with a clear message instead of deep inside imshow().
        raise ValueError(
            "Expected the .npy file to contain a 2-D label array."
        )

    # Label ids belonging to each displayed group.
    grouped_mapping = {
        "Background": [0],
        "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
        "Face": [2, 23, 24, 25, 26, 27],  # Face Neck, Lips, Teeth, Tongue
        "Hair": [3],  # Hair
        "Skin (Hands, Feet, Body)": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],  # Hands, Feet, Arms, Legs, Torso
    }

    # Display colour of each group.
    group_colors = {
        "Background": "black",
        "Clothes": "magenta",
        "Face": "orange",
        "Hair": "brown",
        "Skin (Hands, Feet, Body)": "cyan",
    }

    # Start all black so hidden / unmapped labels need no special handling.
    grouped_mask = np.zeros((*data.shape, 3), dtype=np.uint8)

    for category, indices in grouped_mapping.items():
        if category == category_to_hide:
            continue  # Leave the hidden group black.
        # The colour is constant per group, so convert it once, not per id.
        rgb = [int(c * 255) for c in mcolors.to_rgb(group_colors[category])]
        grouped_mask[np.isin(data, indices)] = rgb

    # Work on the figure object directly: pyplot's implicit "current figure"
    # is global state and could belong to a different request.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(grouped_mask)
        ax.axis("off")
        fig.tight_layout()

        # A unique file per call keeps concurrent requests from overwriting
        # each other's output. Only the name is reserved here; savefig()
        # reopens the path itself.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # Always release the figure, even when rendering fails.
        plt.close(fig)

    return output_path

# Gradio UI: a file upload and a single-choice category picker are passed to
# process_mask, and the image it returns is displayed as the output.
demo = gr.Interface(
    title="Segmentation Mask Editor",
    description="Upload a .npy segmentation file and select a category to mask (hide with black).",
    fn=process_mask,
    inputs=[
        gr.File(label="Upload .npy Segmentation File"),
        gr.Radio(
            choices=[
                "Background",
                "Clothes",
                "Face",
                "Hair",
                "Skin (Hands, Feet, Body)",
            ],
            label="Select Category to Hide",
        ),
    ],
    outputs=gr.Image(label="Modified Segmentation Mask"),
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()