File size: 1,986 Bytes
56a9e12
e657276
56a9e12
e657276
56a9e12
896437a
56a9e12
 
896437a
56a9e12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
896437a
56a9e12
e657276
56a9e12
 
e657276
6984480
56a9e12
6984480
56a9e12
896437a
56a9e12
 
 
 
 
 
 
 
 
e657276
9112e74
 
56a9e12
e657276
9112e74
56a9e12
c396ac7
e657276
 
 
 
 
9112e74
e657276
56a9e12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import torch
import gradio as gr
from PIL import Image
from process import load_seg_model, get_palette, generate_mask

# Pick the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def load_model(model_dir='model', checkpoint_name='cloth_segm.pth'):
    """Load the segmentation network and its colour palette.

    The checkpoint path is relative, so it resolves against the current
    working directory (the repository root on Hugging Face Spaces).

    Args:
        model_dir: Directory that holds the checkpoint. Defaults to
            ``'model'``.
        checkpoint_name: File name of the checkpoint inside *model_dir*.
            Defaults to ``'cloth_segm.pth'``.

    Returns:
        A ``(net, palette)`` tuple: the object returned by
        ``load_seg_model`` and the object returned by ``get_palette(4)``.

    Raises:
        FileNotFoundError: If the checkpoint is missing, or the path is not
            a regular file.
        RuntimeError: If loading the checkpoint or building the palette
            fails. The original exception is chained as ``__cause__``.
    """
    checkpoint_path = os.path.join(model_dir, checkpoint_name)

    # Verify the model exists (it must be pre-uploaded to HF Spaces).
    # isfile() rather than exists(): a directory with the checkpoint's name
    # should be reported here, not surface later as an opaque load error.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Model not found at {checkpoint_path}. "
            "Please upload the model file to your Space's repository."
        )

    try:
        net = load_seg_model(checkpoint_path, device=device)
        palette = get_palette(4)
    except Exception as e:
        # Chain the cause so the real traceback is not lost behind the wrapper.
        raise RuntimeError(f"Model loading failed: {e}") from e
    return net, palette

# Load the model once at import time so every request reuses it. If
# load_model() raises (missing or unloadable checkpoint), startup aborts
# immediately instead of failing on the first request.
(net, palette) = load_model()

def process_image(img: Image.Image) -> Image.Image:
    """Run cloth segmentation on an uploaded image.

    Args:
        img: Image from the ``type="pil"`` Gradio input component, or
            ``None`` when nothing has been uploaded.

    Returns:
        The result of ``generate_mask`` for *img*. It is displayed by a
        ``type="pil"`` output component, so presumably a PIL image. TODO:
        confirm against ``process.generate_mask``.

    Raises:
        gr.Error: If no image was uploaded or segmentation fails. The
            message is shown to the user in the Gradio UI.
    """
    if img is None:
        raise gr.Error("Please upload an image first")

    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Show a readable message in the UI, but chain the original exception
        # so the full traceback is still available in the server logs.
        raise gr.Error(f"Processing failed: {e}") from e

# Gradio interface: page heading, short blurb, then a two-column layout with
# the upload + button on the left and the segmentation result on the right.
title = "Cloth Segmentation Demo"
description = "\nUpload an image to get cloth segmentation using U2NET.\n"

with gr.Blocks() as demo:
    gr.Markdown("## " + title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: user input.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            submit_btn = gr.Button("Process", variant="primary")
        # Right column: model output.
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type="pil")

    # Clicking the button sends the uploaded image through process_image and
    # renders whatever it returns in the output component.
    submit_btn.click(process_image, inputs=input_image, outputs=output_image)

def _main():
    """Start the Gradio server for the demo defined above."""
    demo.launch()


if __name__ == "__main__":
    _main()