File size: 2,411 Bytes
c0827b5 d9dc7e6 c0827b5 d9dc7e6 38efd19 d9dc7e6 6ca5592 d9dc7e6 5282488 d9dc7e6 c0827b5 38efd19 c0827b5 5282488 c0827b5 5282488 6ca5592 5282488 c0827b5 d9dc7e6 5282488 c0827b5 d9dc7e6 c0827b5 d9dc7e6 c0827b5 d9dc7e6 6ca5592 38efd19 5282488 c0827b5 5282488 c0827b5 5282488 c0827b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import requests
from gradio_client import Client, handle_file
from PIL import Image
def get_segmentation_mask(image_url):
    """Request a segmentation mask from the ``facebook/sapiens-seg`` Space.

    Args:
        image_url: Location of the image to segment. It is wrapped with
            ``handle_file`` before being sent to the ``/process_image``
            endpoint; the caller in this file passes a local file path.

    Returns:
        The array loaded with ``np.load`` from the second element of the
        prediction result.
    """
    sapiens_client = Client("facebook/sapiens-seg")
    prediction = sapiens_client.predict(
        image=handle_file(image_url),
        model_name="1b",
        api_name="/process_image",
    )
    # The second output is the saved ``.npy`` mask file.
    mask_file = prediction[1]
    return np.load(mask_file)
# Segmentation class ids grouped under the category names offered in the UI.
_CATEGORY_CLASS_IDS = {
    "Background": [0],
    "Clothes": [1, 12, 22, 8, 9, 17, 18],  # Includes Shoes, Socks, Slippers
    "Face": [2, 23, 24, 25, 26, 27],  # Face Neck, Lips, Teeth, Tongue
    "Hair": [3],  # Hair
    "Skin": [4, 5, 6, 7, 10, 11, 13, 14, 15, 16, 19, 20, 21],  # Hands, Feet, Arms, Legs, Torso
}


def _select_category_pixels(mask_data, categories):
    """Return a boolean array marking pixels whose class id belongs to *categories*."""
    class_ids = [
        idx
        for category in categories
        # Unknown category names contribute no class ids.
        for idx in _CATEGORY_CLASS_IDS.get(category, [])
    ]
    return np.isin(mask_data, class_ids)


def process_image(image, categories_to_hide):
    """Keep only the selected categories of an image; make the rest transparent.

    Args:
        image: The uploaded file. Either an object exposing the file path as
            ``.name`` or a plain path string.
        categories_to_hide: Category names ("Background", "Clothes", "Face",
            "Hair", "Skin"). NOTE: despite the parameter name, these are the
            regions that are *preserved* in the output, matching the UI label
            "Select Categories to Preserve". The name is kept so existing
            callers are unaffected.

    Returns:
        An RGBA ``PIL.Image.Image`` in which pixels of the selected categories
        keep their original values and every other pixel is fully
        transparent (all channels zero).

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Accept both file-like wrappers (path in ``.name``) and plain path strings.
    source_path = getattr(image, "name", image)
    with Image.open(source_path) as source:
        rgba_image = source.convert("RGBA")

    # Use a private temporary directory per call: a fixed file name in the
    # working directory would be overwritten by concurrent requests and was
    # never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "input.png")
        rgba_image.save(tmp_path)
        mask_data = get_segmentation_mask(tmp_path)

    image_array = np.array(rgba_image, dtype=np.uint8)

    # Assumes the mask has the same height/width as the image so it can index
    # the pixel grid directly -- TODO confirm against the Space's output.
    keep = _select_category_pixels(mask_data, categories_to_hide or [])

    # Start fully transparent and copy over only the pixels to preserve.
    transparent_image = np.zeros_like(image_array, dtype=np.uint8)
    transparent_image[keep] = image_array[keep]

    # The RGBA mode is inferred from the (H, W, 4) uint8 array.
    return Image.fromarray(transparent_image)
# Gradio UI wiring: a file upload and a category checklist are passed to
# process_image, and the PIL image it returns is displayed as the result.
_category_choices = ["Background", "Clothes", "Face", "Hair", "Skin"]
_interface_inputs = [
    gr.File(label="Upload an Image"),
    gr.CheckboxGroup(_category_choices, label="Select Categories to Preserve"),
]
_interface_output = gr.Image(label="Masked Image", type="pil")

demo = gr.Interface(
    fn=process_image,
    inputs=_interface_inputs,
    outputs=_interface_output,
    title="Segmentation Mask Editor",
    description=(
        "Upload an image, generate a segmentation mask, and select categories "
        "to preserve while making the rest transparent."
    ),
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|