File size: 1,912 Bytes
896437a
 
 
1c6ea49
896437a
 
6984480
896437a
c396ac7
 
 
 
 
 
 
 
 
896437a
 
6984480
c396ac7
6984480
c396ac7
6984480
 
896437a
c396ac7
1369068
c396ac7
6984480
c396ac7
 
 
 
 
1c6ea49
c396ac7
 
9112e74
 
c396ac7
9112e74
c396ac7
 
9112e74
 
541e07e
c396ac7
 
 
 
 
1369068
c396ac7
 
 
9112e74
c396ac7
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize models
# The segmentation network and its colour palette are loaded once, at import
# time, and shared by every request. Any failure aborts startup with a
# RuntimeError whose __cause__ is the original exception.
try:
    checkpoint_path = 'model/cloth_segm.pth'
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    net = load_seg_model(checkpoint_path, device=device)
    # NOTE(review): 4 presumably is the number of segmentation classes the
    # palette must colour — confirm against process.get_palette.
    palette = get_palette(4)
except Exception as e:
    # Chain explicitly so the traceback reports the original error as the
    # direct cause instead of "during handling of the above exception".
    raise RuntimeError(f"Failed to initialize models: {e}") from e

def run(img):
    """Segment clothing in an uploaded image for the Gradio UI.

    Args:
        img: The value delivered by the input image component (configured
            with ``type="pil"`` in this app), or ``None`` when nothing has
            been uploaded.

    Returns:
        Whatever ``generate_mask`` produces for ``img`` using the module-level
        ``net``, ``palette`` and ``device``; it is shown in the output image
        component.

    Raises:
        gr.Error: If no image was provided, or if segmentation fails. Gradio
            displays the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Keep the original exception as __cause__ so server-side tracebacks
        # still show where segmentation actually failed.
        raise gr.Error(f"Error processing image: {e}") from e

# Handle examples
# Example images shown under the input/output panels are read from this
# directory; it is optional.
image_dir = 'input'
EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def collect_examples(directory, extensions=EXAMPLE_EXTENSIONS):
    """Return one ``[path]`` row per image file in *directory*, sorted by name.

    Each row is a single-element list because ``gr.Examples`` expects one
    value per input component. The extension check is case-insensitive.

    Returns an empty list when *directory* does not exist or is not a
    directory (``os.listdir`` would raise ``NotADirectoryError`` on a plain
    file). Entries that are not regular files, such as a sub-directory named
    ``foo.png``, are skipped.
    """
    if not os.path.isdir(directory):
        return []
    candidates = (
        os.path.join(directory, name) for name in sorted(os.listdir(directory))
    )
    return [
        [path]
        for path in candidates
        if path.lower().endswith(extensions) and os.path.isfile(path)
    ]


examples = collect_examples(image_dir)

# Create interface
# Layout: the upload panel is on the left and the segmentation result on the
# right. Clickable examples (when any were found) sit underneath, followed by
# a button that runs ``run`` on the current upload.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    # Only build the examples panel when example images exist; otherwise an
    # empty "Example Images" section would be rendered.
    if examples:
        with gr.Row():
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=run,
                # Gradio precomputes and stores the outputs for these examples
                # so selecting one shows the cached result.
                cache_examples=True,
                label="Example Images",
            )

    submit_btn = gr.Button("Segment", variant="primary")
    submit_btn.click(fn=run, inputs=input_image, outputs=output_image)

# Launch with appropriate settings
def _launch_with_fallback(app):
    """Serve *app* on all interfaces at port 7860, falling back to a share link.

    If the first launch raises, the error is printed and the app is launched
    again with ``share=True``. An exception from that second attempt
    propagates to the caller.

    NOTE(review): share=True publishes the app through a public Gradio share
    URL; confirm that is an acceptable fallback for this deployment.
    """
    try:
        app.launch(server_name="0.0.0.0", server_port=7860)
    except Exception as launch_error:
        print("Error launching app: " + str(launch_error))
        app.launch(share=True)


_launch_with_fallback(demo)