File size: 2,411 Bytes
c0827b5
 
 
 
d9dc7e6
 
 
 
c0827b5
d9dc7e6
 
 
38efd19
d9dc7e6
6ca5592
d9dc7e6
5282488
d9dc7e6
 
 
 
c0827b5
 
 
 
 
 
 
38efd19
c0827b5
 
5282488
 
c0827b5
5282488
 
 
 
 
6ca5592
 
5282488
 
 
 
c0827b5
d9dc7e6
5282488
c0827b5
d9dc7e6
c0827b5
 
 
d9dc7e6
c0827b5
d9dc7e6
6ca5592
38efd19
5282488
c0827b5
5282488
c0827b5
5282488
c0827b5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import requests
from gradio_client import Client, handle_file
from PIL import Image

def get_segmentation_mask(image_url):
    """Segment an image with the hosted ``facebook/sapiens-seg`` Space.

    A fresh ``gradio_client.Client`` is created on every call, and the image is
    sent to the Space's ``/process_image`` endpoint using the ``"1b"`` model.

    Args:
        image_url: Path or URL of the image, passed through ``handle_file``.

    Returns:
        The array loaded with ``np.load`` from the endpoint's second output
        (expected to be the path of a saved ``.npy`` label mask).
    """
    sapiens = Client("facebook/sapiens-seg")
    outputs = sapiens.predict(
        image=handle_file(image_url),
        model_name="1b",
        api_name="/process_image",
    )
    # Index 1 of the endpoint's outputs holds the .npy mask file.
    mask_path = outputs[1]
    return np.load(mask_path)

# Map each user-facing category to the segmentation class indices it covers.
_GROUPED_MAPPING = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # includes shoes, socks, slippers
    "Face": [2, 23, 24, 25, 26, 27],  # face/neck, lips, teeth, tongue
    "Hair": [3],  # hair
    "Skin": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],  # hands, feet, arms, legs, torso
}


def _resolve_upload_path(upload):
    """Return the filesystem path held by a Gradio ``gr.File`` value."""
    # Older Gradio versions pass a tempfile wrapper exposing ``.name``; newer
    # ones (``type="filepath"``, the default) pass the path string itself.
    if isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def _segment(image):
    """Write *image* to a private temporary PNG, segment it, then delete the file."""
    # A unique file per call: a fixed name in the working directory would let
    # concurrent requests overwrite each other's input before it is uploaded.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(tmp_path)
        return get_segmentation_mask(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup; never mask the real result or error.
            pass


def _build_keep_mask(mask_data, categories):
    """Return a boolean array that is True where the mask's class is in *categories*."""
    indices = [
        idx for category in categories for idx in _GROUPED_MAPPING.get(category, [])
    ]
    # np.isin with an empty index list yields an all-False mask, as intended.
    return np.isin(mask_data, indices)


def process_image(image, categories_to_hide):
    """Segment an uploaded image and keep only the pixels of the selected categories.

    Args:
        image: ``gr.File`` value: either a path string or an object whose
            ``.name`` attribute is the path.
        categories_to_hide: Category names (keys of ``_GROUPED_MAPPING``) whose
            pixels are *kept*. Despite the parameter name, the UI labels this
            "Select Categories to Preserve"; every other pixel becomes
            transparent. Unknown names are ignored, and ``None`` means none.

    Returns:
        An RGBA ``PIL.Image`` of the input's size. Every channel, including
        alpha, is 0 outside the selected regions.

    Raises:
        gr.Error: If no file was uploaded.
        ValueError: If the mask's height/width differ from the image's.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Close the source file promptly; convert() returns an independent image.
    with Image.open(_resolve_upload_path(image)) as source:
        rgba = source.convert("RGBA")

    mask_data = _segment(rgba)
    image_array = np.asarray(rgba, dtype=np.uint8)
    if mask_data.shape != image_array.shape[:2]:
        raise ValueError(
            f"Segmentation mask shape {mask_data.shape} does not match "
            f"image size {image_array.shape[:2]}"
        )

    keep = _build_keep_mask(mask_data, categories_to_hide or [])

    # Start fully transparent and copy over only the selected pixels.
    result = np.zeros_like(image_array)
    result[keep] = image_array[keep]

    # An (H, W, 4) uint8 array is inferred as RGBA; the ``mode`` argument is deprecated.
    return Image.fromarray(result)

# Gradio UI: an image upload plus a checklist of body-part groups. process_image
# receives (uploaded file, list of checked category names).
_upload_input = gr.File(label="Upload an Image")
_category_input = gr.CheckboxGroup(
    ["Background", "Clothes", "Face", "Hair", "Skin"],
    label="Select Categories to Preserve",
)
_result_output = gr.Image(label="Masked Image", type="pil")

demo = gr.Interface(
    fn=process_image,
    inputs=[_upload_input, _category_input],
    outputs=_result_output,
    title="Segmentation Mask Editor",
    description="Upload an image, generate a segmentation mask, and select categories to preserve while making the rest transparent.",
)


def _main():
    """Start the Gradio server for the demo."""
    demo.launch()


if __name__ == "__main__":
    _main()