File size: 3,792 Bytes
896437a
 
 
1c6ea49
896437a
 
6984480
896437a
1c6ea49
6984480
 
 
 
 
 
 
 
 
 
896437a
 
6984480
 
 
 
 
 
 
 
 
896437a
 
6984480
 
 
896437a
 
 
6984480
 
 
 
 
 
 
 
 
 
896437a
9112e74
1c6ea49
 
 
 
 
 
 
 
 
 
 
 
9112e74
1369068
6984480
 
 
 
541e07e
1c6ea49
1369068
1c6ea49
 
9112e74
 
 
1369068
9112e74
 
1369068
9112e74
 
541e07e
9112e74
 
 
1369068
6984480
1369068
9112e74
1369068
 
 
 
 
 
 
 
 
 
 
 
 
9112e74
6984480
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask

# Run inference on the GPU when torch reports CUDA as available, else on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def read_content(file_path: str) -> str:
    """Return the UTF-8 text of *file_path*, or an empty string on failure.

    Failures are never propagated: a missing file prints a warning, any other
    error prints the error text, and in both cases "" is returned.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            return handle.read()
    except FileNotFoundError:
        problem = f"Warning: File {file_path} not found"
    except Exception as err:
        problem = f"Error reading file {file_path}: {str(err)}"
    # Only reached when the read failed; report it and fall back to no content.
    print(problem)
    return ""

def initialize_and_load_models():
    """Load the segmentation model from 'model/cloth_segm.pth'.

    Returns:
        The object returned by load_seg_model() for the module-level device,
        or None when the checkpoint file is missing or loading raises. The
        reason is printed in either failure case.
    """
    weights_file = 'model/cloth_segm.pth'
    model = None
    try:
        if os.path.exists(weights_file):
            model = load_seg_model(weights_file, device=device)
        else:
            raise FileNotFoundError(f"Model checkpoint not found at {weights_file}")
    except Exception as err:
        # Missing checkpoint and loader failures are reported the same way.
        print(f"Error loading model: {str(err)}")
    return model

# Load the segmentation network once, at import time.
# initialize_and_load_models() reports failures by printing and returning None,
# so convert that into a hard startup error: run() cannot work without a model.
net = initialize_and_load_models()
if net is None:
    raise RuntimeError("Failed to load model - check logs for details")

# Palette passed to generate_mask() on every request.
# NOTE(review): 4 is presumably the number of segmentation classes — confirm
# against process.get_palette.
palette = get_palette(4)

def run(img):
    """Generate the cloth-segmentation mask for an uploaded image.

    Args:
        img: Image supplied by the Gradio input component, or None when
            nothing has been uploaded.

    Returns:
        The value produced by generate_mask() for the image.

    Raises:
        gr.Error: If no image was uploaded ("No image uploaded"), if
            generate_mask() raises ("Error processing image: ..."), or if it
            returns None ("Failed to generate mask").
    """
    if img is None:
        raise gr.Error("No image uploaded")
    try:
        cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Keep the original exception chained so the server log shows the
        # real traceback, not just the user-facing message.
        raise gr.Error(f"Error processing image: {str(e)}") from e
    # Checked outside the try block: raising gr.Error inside it would be
    # caught by the handler above and re-wrapped with a misleading prefix.
    if cloth_seg is None:
        raise gr.Error("Failed to generate mask")
    return cloth_seg

# Custom CSS handed to gr.Blocks(css=...). The #image_upload rules target the
# input gr.Image (elem_id="image_upload"); .footer and .acknowledgments style
# the HTML block rendered at the bottom of the page.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

# Gather the sample images (PNG/JPEG files directly inside ./input), sorted by path.
image_dir = 'input'
if os.path.exists(image_dir):
    image_list = sorted(
        os.path.join(image_dir, entry)
        for entry in os.listdir(image_dir)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
else:
    # No sample directory: the examples gallery simply gets no entries.
    image_list = []
# One single-element row per image, matching the single input given to gr.Examples.
examples = [[path] for path in image_list]

# Build the page: header, input/output images side by side, then the examples
# gallery and the Run button, and finally a static footer.
with gr.Blocks(css=css) as demo:
    # Header markup lives in header.html; read_content() returns "" if the
    # file is missing or unreadable, so the page still renders without it.
    gr.HTML(read_content("header.html"))

    with gr.Row():
        with gr.Column():
            # type="pil" asks Gradio to pass the upload to run() as a PIL image;
            # elem_id ties this component to the #image_upload CSS rules.
            image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")

        with gr.Column():
            # Displays whatever run() returns.
            image_out = gr.Image(label="Output", elem_id="output-img")

    with gr.Row():
        # Clicking an example fills the input image; it does not call run().
        gr.Examples(
            examples=examples,
            inputs=[image],
            label="Examples - Input Images",
            examples_per_page=12
        )
        btn = gr.Button("Run!", variant="primary")

    # Segmentation is only triggered explicitly through the button.
    btn.click(fn=run, inputs=[image], outputs=[image_out])

    # Static credits; styled by the .footer / .acknowledgments CSS rules.
    gr.HTML(
        """
        <div class="footer">
            <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
        </div>
        <div class="acknowledgments">
            <p><h4>ACKNOWLEDGEMENTS</h4></p>
            <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
            <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
        </div>
        """
    )

# For Hugging Face Spaces, use launch() without share=True
demo.launch()