File size: 3,792 Bytes
0038d2b
6648f32
0038d2b
6648f32
0038d2b
56a9e12
0038d2b
6648f32
0038d2b
 
 
 
 
 
 
 
 
 
 
6648f32
0038d2b
 
6648f32
0038d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6648f32
77304c1
6648f32
0038d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6648f32
 
0038d2b
 
6648f32
0038d2b
 
 
6648f32
0038d2b
 
 
 
6648f32
0038d2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77304c1
6648f32
0038d2b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask

# Pick the torch device string once at import time: use the GPU when torch
# reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def read_content(file_path: str) -> str:
    """Return the UTF-8 text of *file_path*, or "" when it cannot be read.

    This is a best-effort reader: a missing file produces a warning on
    stdout, any other failure produces an error line on stdout, and in
    both cases an empty string is returned instead of raising.
    """
    content = ""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()
    except FileNotFoundError:
        print(f"Warning: File {file_path} not found")
    except Exception as exc:
        print(f"Error reading file {file_path}: {exc!s}")
    return content

def initialize_and_load_models():
    """Load the segmentation model from the fixed checkpoint path.

    Returns whatever ``load_seg_model`` returns for 'model/cloth_segm.pth'
    on the module-level ``device``. If the checkpoint file is absent or
    loading raises, the reason is printed and ``None`` is returned, so the
    caller decides how to react.
    """
    checkpoint_path = 'model/cloth_segm.pth'
    try:
        if os.path.exists(checkpoint_path):
            return load_seg_model(checkpoint_path, device=device)
        # Raised inside the try on purpose: it is reported by the handler below.
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    except Exception as exc:
        print(f"Error loading model: {exc!s}")
        return None

# Load the segmentation network once, at import time.
# initialize_and_load_models() prints the reason and returns None on any
# failure, so abort start-up here rather than continue without a model.
net = initialize_and_load_models()
if net is None:
    raise RuntimeError("Failed to load model - check logs for details")

# Palette object that run() passes to generate_mask().
# NOTE(review): the argument 4 is presumably the number of segmentation
# classes — confirm against process.get_palette.
palette = get_palette(4)

def run(img):
    """Generate the cloth-segmentation mask for an uploaded image.

    Args:
        img: Value delivered by the input image component, or ``None`` when
            nothing has been uploaded.

    Returns:
        The non-None result of ``generate_mask`` for ``img``, computed with
        the module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: if no image was uploaded, if ``generate_mask`` raises
            (message prefixed with "Error processing image:"), or if it
            returns ``None`` ("Failed to generate mask").
    """
    if img is None:
        raise gr.Error("No image uploaded")
    # Keep the try body to the one call that can fail unexpectedly.
    try:
        cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise gr.Error(f"Error processing image: {str(e)}") from e
    # Checked outside the try: inside it, this gr.Error would be caught by the
    # handler above and re-wrapped as
    # "Error processing image: Failed to generate mask".
    if cloth_seg is None:
        raise gr.Error("Failed to generate mask")
    return cloth_seg

# CSS styling
# Custom stylesheet handed to gr.Blocks(css=...) further down.
#   #image_upload ...      -> the input gr.Image (elem_id="image_upload");
#                             enforces a 400px minimum height.
#   .footer, .footer>p     -> the <div class="footer"> credit line in the
#                             footer HTML.
#   .acknowledgments h4    -> the heading inside <div class="acknowledgments">.
#   .dark ...              -> overrides for the footer colours; presumably
#                             active under Gradio's dark theme — TODO confirm.
#   .container             -> not referenced by any markup in this file;
#                             presumably used by header.html — TODO confirm.
# Keep comments outside the literal: everything between the quotes is CSS.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

# Collect example images
# Every entry directly inside image_dir whose name ends in .png/.jpg/.jpeg
# (case-insensitive) becomes one example, ordered by path so the list is the
# same on every start-up.
image_dir = 'input'
image_list = []
# isdir() rather than exists(): a regular file named "input" passes an
# exists() check and then makes os.listdir() raise NotADirectoryError,
# which would abort the whole app at import time.
if os.path.isdir(image_dir):
    image_list = sorted(
        os.path.join(image_dir, file)
        for file in os.listdir(image_dir)
        if file.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
# One single-element row per image, matching the single input component
# that gr.Examples is given below.
examples = [[img] for img in image_list]

# Build the demo page. Components appear in the order they are created:
# header, input/output images side by side, examples plus the run button,
# then a static footer.
with gr.Blocks(css=css) as demo:
    # Header markup comes from header.html; read_content() returns "" when
    # the file is missing or unreadable, so the page still builds without it.
    gr.HTML(read_content("header.html"))

    with gr.Row():
        with gr.Column():
            # elem_id ties this component to the #image_upload rules in css.
            image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")

        with gr.Column():
            # Displays whatever run() returns.
            image_out = gr.Image(label="Output", elem_id="output-img")

    with gr.Row():
        # Clicking an example only fills the input image; no fn is attached
        # here, so the mask is produced when the button is pressed.
        gr.Examples(
            examples=examples,
            inputs=[image],
            label="Examples - Input Images",
            examples_per_page=12
        )
        btn = gr.Button("Run!", variant="primary")

    # Wire the button: the input image goes to run(), its result to image_out.
    btn.click(fn=run, inputs=[image], outputs=[image_out])

    # Static footer and acknowledgements, styled by the .footer and
    # .acknowledgments rules in css.
    # NOTE(review): the "WildOctopus" link has an empty href, so it does not
    # lead to the model page — fill in the real URL.
    # NOTE(review): <h4> nested inside <p> is not valid HTML; browsers close
    # the <p> early. Left unchanged here because the markup is runtime text.
    gr.HTML(
        """
        <div class="footer">
            <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
        </div>
        <div class="acknowledgments">
            <p><h4>ACKNOWLEDGEMENTS</h4></p>
            <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
            <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
        </div>
        """
    )

# For Hugging Face Spaces, use launch() without share=True
# There is no __main__ guard, so this runs whenever the module is executed
# or imported, starting the Gradio server with its default settings.
demo.launch()