File size: 3,104 Bytes
c0827b5
 
 
 
d9dc7e6
 
 
 
2b6368a
c0827b5
d9dc7e6
 
 
38efd19
d9dc7e6
a12a7c1
6ca5592
d9dc7e6
5282488
d9dc7e6
a12a7c1
d9dc7e6
 
c0827b5
 
 
 
 
a12a7c1
c0827b5
38efd19
c0827b5
 
5282488
 
c0827b5
5282488
 
 
a12a7c1
5282488
a12a7c1
6ca5592
 
5282488
a12a7c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5282488
 
c0827b5
d9dc7e6
5282488
c0827b5
d9dc7e6
c0827b5
 
 
d9dc7e6
c0827b5
d9dc7e6
6ca5592
38efd19
5282488
c0827b5
5282488
c0827b5
5282488
c0827b5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import tempfile
from io import BytesIO

import cv2
import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import requests
from gradio_client import Client, handle_file
from PIL import Image

def get_segmentation_mask(image_url):
    """Request a segmentation mask for an image and load it as an array.

    A new client for the "facebook/sapiens-seg" app is created on every
    call, the image is submitted to its "/process_image" endpoint with the
    "1b" model, and the ``.npy`` file referenced by the second element of
    the response is loaded with NumPy.

    Args:
        image_url: Location of the image, forwarded to ``handle_file``.

    Returns:
        The array stored in the returned ``.npy`` file.
    """
    prediction = Client("facebook/sapiens-seg").predict(
        image=handle_file(image_url),
        model_name="1b",
        api_name="/process_image",
    )
    # The second element of the response points at the saved .npy mask.
    mask_path = prediction[1]
    return np.load(mask_path)


# Segmentation class indices grouped into the categories offered in the UI.
# The trailing notes are the original author's labels for the model output
# and have not been re-verified against the model's class list.
_GROUPED_CATEGORY_IDS = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
    "Face": [2, 23, 24, 25, 26, 27],  # Face, Neck, Lips, Teeth, Tongue
    "Hair": [3],  # Hair
    "Skin": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],  # Hands, Feet, Arms, Legs, Torso
}

# Clothing regions are grown by this fraction of the smaller mask side.
_CLOTHES_DILATION_FRACTION = 0.05


def _build_selection_mask(mask_data, categories):
    """Return a boolean mask marking the pixels that belong to *categories*."""
    selected_ids = [
        idx
        for category in categories
        for idx in _GROUPED_CATEGORY_IDS.get(category, [])
    ]
    selection = np.isin(mask_data, selected_ids)

    if "Clothes" in categories:
        clothing = np.isin(mask_data, _GROUPED_CATEGORY_IDS["Clothes"]).astype(np.uint8)

        # Square kernel sized relative to the mask, never smaller than 1 px.
        height, width = clothing.shape
        kernel_size = max(1, int(_CLOTHES_DILATION_FRACTION * min(height, width)))
        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        # Grow the clothing region so its border pixels are selected too.
        selection |= cv2.dilate(clothing, kernel, iterations=1).astype(bool)

    return selection


def process_image(image, categories_to_hide):
    """Keep only the selected categories of an image; make the rest transparent.

    The upload is converted to RGBA, re-encoded as PNG into a private
    temporary directory and sent to ``get_segmentation_mask``. Pixels whose
    mask value belongs to one of the selected categories are copied into an
    otherwise all-zero (fully transparent) RGBA image. When "Clothes" is
    selected, the clothing region is additionally dilated before copying.

    Args:
        image: The uploaded file: either a path (``str`` / ``os.PathLike``)
            or an object exposing the path as ``.name``.
        categories_to_hide: Category names to keep visible. Despite the
            parameter name, these are the regions that are preserved;
            everything else becomes transparent. Unknown names are ignored
            and ``None`` is treated as an empty selection.

    Returns:
        A PIL RGBA image containing only the selected regions.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image before running the segmentation.")

    # Accept both plain paths and file-like wrappers that carry `.name`.
    if isinstance(image, (str, os.PathLike)):
        source_path = image
    else:
        source_path = image.name
    categories = list(categories_to_hide or [])

    # Close the source file as soon as the pixels have been converted.
    with Image.open(source_path) as source:
        rgba_image = source.convert("RGBA")

    # A per-call temporary directory keeps concurrent requests from
    # overwriting each other's upload and removes the file afterwards.
    with tempfile.TemporaryDirectory() as work_dir:
        upload_path = os.path.join(work_dir, "temp_image.png")
        rgba_image.save(upload_path)
        mask_data = get_segmentation_mask(upload_path)

    image_array = np.array(rgba_image, dtype=np.uint8)
    keep = _build_selection_mask(mask_data, categories)

    # Start fully transparent, then copy over only the selected pixels.
    transparent_image = np.zeros_like(image_array, dtype=np.uint8)
    transparent_image[keep] = image_array[keep]

    return Image.fromarray(transparent_image, mode="RGBA")

# Gradio UI: a file upload plus a category checklist feed `process_image`,
# and the resulting PIL image is shown as the output.
demo = gr.Interface(
    title="Segmentation Mask Editor",
    description="Upload an image, generate a segmentation mask, and select categories to preserve while making the rest transparent.",
    fn=process_image,
    inputs=[
        gr.File(label="Upload an Image"),
        gr.CheckboxGroup(
            choices=["Background", "Clothes", "Face", "Hair", "Skin"],
            label="Select Categories to Preserve",
        ),
    ],
    outputs=gr.Image(type="pil", label="Masked Image"),
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()