File size: 1,986 Bytes
56a9e12 e657276 56a9e12 e657276 56a9e12 896437a 56a9e12 896437a 56a9e12 896437a 56a9e12 e657276 56a9e12 e657276 6984480 56a9e12 6984480 56a9e12 896437a 56a9e12 e657276 9112e74 56a9e12 e657276 9112e74 56a9e12 c396ac7 e657276 9112e74 e657276 56a9e12 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import os
import torch
import gradio as gr
from PIL import Image
from process import load_seg_model, get_palette, generate_mask
# Pick the compute device once at import time: use CUDA whenever PyTorch
# reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model(model_dir='model', checkpoint_name='cloth_segm.pth'):
    """Load the cloth-segmentation network and its colour palette.

    The checkpoint path is built with ``os.path.join(model_dir,
    checkpoint_name)``. With the default relative ``model_dir`` it is
    resolved against the current working directory. The file is never
    downloaded here; it must already exist (e.g. uploaded to the Space's
    repository).

    Args:
        model_dir: Directory that holds the checkpoint. Defaults to ``'model'``.
        checkpoint_name: File name of the checkpoint inside ``model_dir``.
            Defaults to ``'cloth_segm.pth'``.

    Returns:
        A ``(net, palette)`` tuple:
        - ``net`` is the object returned by ``load_seg_model`` for the
          module-level ``device``.
        - ``palette`` is the value returned by ``get_palette(4)``.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist or is not a
            regular file (e.g. a directory with that name).
        RuntimeError: If loading the network or building the palette fails.
            The original exception is chained as ``__cause__``.
    """
    checkpoint_path = os.path.join(model_dir, checkpoint_name)

    # isfile() rather than exists(): a directory with the checkpoint's name
    # must be reported as a missing model, not fail later inside the loader.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Model not found at {checkpoint_path}. "
            "Please upload the model file to your Space's repository."
        )

    try:
        net = load_seg_model(checkpoint_path, device=device)
        # NOTE(review): 4 is presumably the number of segmentation classes
        # (background + cloth classes); confirm against process.get_palette.
        palette = get_palette(4)
    except Exception as e:
        # Chain the original error so its traceback survives in the logs.
        raise RuntimeError(f"Model loading failed: {e}") from e
    return net, palette
# Load the network and palette once, at import time. process_image reads
# these module-level names on every request, and a missing or broken
# checkpoint aborts startup right here instead of on the first request.
[net, palette] = load_model()
def process_image(img: Image.Image) -> Image.Image:
    """Run cloth segmentation on one uploaded image.

    Args:
        img: The uploaded image, or ``None`` when the user pressed the
            button without uploading anything.

    Returns:
        The result of ``generate_mask`` for ``img``, computed with the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If no image was supplied, or if ``generate_mask`` raises.
            In the latter case the underlying exception is chained as
            ``__cause__`` so its traceback is preserved.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Keep the short message for the UI, but chain the real error so the
        # server-side traceback still shows where segmentation failed.
        raise gr.Error(f"Processing failed: {e}") from e
# --- Gradio user interface -------------------------------------------------
title = "Cloth Segmentation Demo"
description = "\nUpload an image to get cloth segmentation using U2NET.\n"

with gr.Blocks() as demo:
    # Page header: heading followed by a one-line description.
    gr.Markdown("## " + title)
    gr.Markdown(description)

    # Two side-by-side columns: upload + button on the left, result on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Segmentation Result")

    # Wire the button: uploaded image -> process_image -> result panel.
    submit_btn.click(process_image, input_image, output_image)
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script,
    # not when it is imported.
    demo.launch()