File size: 2,411 Bytes
56a9e12
6648f32
 
 
 
56a9e12
77304c1
6648f32
77304c1
6648f32
77304c1
 
 
 
 
6648f32
77304c1
 
 
6648f32
77304c1
6648f32
77304c1
 
 
 
 
 
 
6648f32
77304c1
6648f32
77304c1
6648f32
77304c1
 
 
 
6648f32
 
 
77304c1
 
 
 
 
 
 
6648f32
 
77304c1
 
 
 
6648f32
77304c1
 
 
 
 
 
 
 
 
6648f32
77304c1
 
 
 
 
 
6648f32
77304c1
 
 
 
 
 
6648f32
77304c1
84e526f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os

# Load the segmentation model once at import time; process_image reuses these
# module-level objects for every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"

try:
    net = load_seg_model(model_path, device=device)
    # NOTE(review): 4 is presumably the number of classes the model predicts
    # (one palette colour per class) — confirm against process.get_palette.
    palette = get_palette(4)
except Exception as e:
    # Fail fast: the app is useless without the model. Chaining with `from e`
    # keeps the original loader traceback attached to the RuntimeError.
    raise RuntimeError(f"Failed to load model: {e}") from e

def process_image(input_img):
    """Run cloth segmentation on one image and return the mask.

    Args:
        input_img: Value from the Gradio input ``Image`` component. It is
            normally a ``PIL.Image.Image``. Any other non-None value is
            converted with ``Image.fromarray`` (presumably a NumPy array —
            confirm if the component's ``type`` changes).

    Returns:
        The result of ``generate_mask`` for the image, using the module-level
        ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If no image was provided, or if conversion or segmentation
            fails. Gradio shows the message to the user.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")

    try:
        # generate_mask is given a PIL image, so normalise other inputs first.
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)

        return generate_mask(input_img, net=net, palette=palette, device=device)
    except Exception as e:
        # Show a readable message in the UI. Chaining with `from e` keeps the
        # real cause in the server-side traceback.
        raise gr.Error(f"Error processing image: {e}") from e

# Build the UI. The left column holds the input image (upload or webcam) and a
# Process button; the right column shows the segmentation result.
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
    # 🧥 Cloth Segmentation App
    Upload an image or capture from your camera to get segmentation results.
    """)
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
                interactive=True
            )
            submit_btn = gr.Button("Process", variant="primary")
        
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False
            )
    
    # Optional examples: shown only when ./input is a directory that holds at
    # least one image file.
    example_dir = "input"
    if os.path.isdir(example_dir):
        candidate_paths = (
            os.path.join(example_dir, name) for name in os.listdir(example_dir)
        )
        # Sort because os.listdir returns entries in arbitrary order. Skip
        # sub-directories whose names happen to end with an image suffix.
        example_images = sorted(
            path
            for path in candidate_paths
            if path.lower().endswith(('.png', '.jpg', '.jpeg'))
            and os.path.isfile(path)
        )
        
        # Avoid rendering an empty Examples panel.
        if example_images:
            # cache_examples=True makes Gradio precompute the outputs for
            # these examples, which runs process_image (and the model) on
            # each one.
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                cache_examples=True,
                label="Example Images"
            )
    
    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image
    )

# Start the Gradio server with launch()'s default settings.
if __name__ == "__main__":
    demo.launch()