# Parts of this source code are adapted from: fashn-ai/sapiens-body-part-segmentation

import os

import gradio as gr
import numpy as np
import spaces
import torch
from gradio.themes.utils import sizes
from PIL import Image
from torchvision import transforms

from utils.vis_utils import get_palette, visualize_mask_with_overlay

# Enable TF32 matmuls on Ampere (compute capability >= 8) and newer GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

# TorchScript checkpoints for each Sapiens model size.
CHECKPOINTS = {
    "0.3B": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
    "0.6B": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
    "1B": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    "2B": "sapiens_2b_goliath_best_goliath_mIoU_8179_epoch_181_torchscript.pt2",
}

# Class-index mapping used by `segment` below. The original flattened file
# never defined it; this is the 28-class Goliath body-part taxonomy as
# assumed from the upstream Space / Sapiens release -- verify against
# utils.vis_utils before relying on the exact ordering.
GOLIATH_CLASSES = [
    "Background", "Apparel", "Face_Neck", "Hair", "Left_Foot", "Left_Hand",
    "Left_Lower_Arm", "Left_Lower_Leg", "Left_Shoe", "Left_Sock",
    "Left_Upper_Arm", "Left_Upper_Leg", "Lower_Clothing", "Right_Foot",
    "Right_Hand", "Right_Lower_Arm", "Right_Lower_Leg", "Right_Shoe",
    "Right_Sock", "Right_Upper_Arm", "Right_Upper_Leg", "Torso",
    "Upper_Clothing", "Lower_Lip", "Upper_Lip", "Lower_Teeth",
    "Upper_Teeth", "Tongue",
]
LABELS_TO_IDS = {label: idx for idx, label in enumerate(GOLIATH_CLASSES)}


def load_model(checkpoint_name: str):
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS[checkpoint_name])
    model = torch.jit.load(checkpoint_path)
    model.eval()
    model.to("cuda")
    return model


# Load every model size up front so switching versions in the UI is instant.
MODELS = {name: load_model(name) for name in CHECKPOINTS}


@torch.inference_mode()
def run_model(model, input_tensor, height, width):
    output = model(input_tensor)
    # Upsample the logits back to the original image resolution before argmax.
    output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
    _, preds = torch.max(output, 1)
    return preds


transform_fn = transforms.Compose(
    [
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ----------------- CORE FUNCTION ----------------- #


@spaces.GPU
def segment(image: Image.Image, model_name: str) -> Image.Image:
    image = image.convert("RGB")  # guard against RGBA/grayscale uploads
    input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
    model = MODELS[model_name]
    preds = run_model(model, input_tensor, height=image.height, width=image.width)
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
    return blended_image


# ----------------- GRADIO UI ----------------- #

with open("banner.html", "r") as file:
    banner = file.read()

with open("tips.html", "r") as file:
    tips = file.read()

CUSTOM_CSS = """
.image-container img {
    max-width: 512px;
    max-height: 512px;
    margin: 0 auto;
    border-radius: 0px;
}
.gradio-container {background-color: #fafafa}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
    gr.HTML(banner)
    gr.HTML(tips)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil", format="png")
            model_name = gr.Dropdown(
                label="Model Version",
                choices=list(CHECKPOINTS.keys()),
                value="0.3B",
            )
            example_model = gr.Examples(
                inputs=input_image,
                examples_per_page=10,
                examples=[
                    os.path.join(ASSETS_DIR, "examples", img)
                    for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
                ],
            )
        with gr.Column():
            result_image = gr.Image(label="Segmentation Result", format="png")
            run_button = gr.Button("Run")
            gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")

    run_button.click(
        fn=segment,
        inputs=[input_image, model_name],
        outputs=[result_image],
    )

if __name__ == "__main__":
    demo.launch(share=False)
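
# ----------------- USAGE SKETCH ----------------- #
# A minimal sketch of calling this demo programmatically with gradio_client
# once the app is running. The URL and api_name below are illustrative
# assumptions, not values confirmed by this file; inspect client.view_api()
# to find the actual endpoint name.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")  # or the hosted Space ID
#   overlay_path = client.predict(
#       handle_file("person.jpg"),  # input_image
#       "0.3B",                     # model_name
#       api_name="/predict",        # hypothetical endpoint name
#   )
#   print(overlay_path)  # local path to the returned segmentation overlay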