File size: 2,681 Bytes
896437a
 
 
1c6ea49
896437a
 
 
 
1c6ea49
 
1369068
896437a
 
 
1369068
896437a
 
 
 
 
 
 
 
9112e74
1c6ea49
 
 
 
 
 
 
 
 
 
 
 
9112e74
1369068
 
1c6ea49
541e07e
1c6ea49
1369068
1c6ea49
 
9112e74
 
 
1369068
9112e74
 
1369068
9112e74
 
541e07e
9112e74
 
 
1369068
9112e74
1369068
9112e74
1369068
 
 
 
 
 
 
 
 
 
 
 
 
9112e74
22dd3c0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask

device = 'cpu'

def read_content(file_path: str) -> str:
    """Read a text file (e.g. an HTML fragment) and return its full contents.

    The file is decoded as UTF-8. Any ``OSError`` or ``UnicodeDecodeError``
    propagates to the caller unchanged.
    """
    with open(file_path, encoding='utf-8') as handle:
        contents = handle.read()
    return contents

def initialize_and_load_models(checkpoint_path: str = 'model/cloth_segm.pth'):
    """Load the cloth-segmentation network onto the module-level ``device``.

    Args:
        checkpoint_path: Path of the checkpoint handed to ``load_seg_model``.
            It defaults to the bundled ``model/cloth_segm.pth``, so existing
            zero-argument calls behave exactly as before.

    Returns:
        Whatever ``load_seg_model`` returns. That is the network object later
        passed to ``generate_mask``.
    """
    return load_seg_model(checkpoint_path, device=device)

# Load the segmentation network once, at import time. `run` reuses this global
# for every request instead of reloading the checkpoint on each call.
net = initialize_and_load_models()
# Palette passed to generate_mask to colour the output mask.
# NOTE(review): the argument 4 is presumably the number of segmentation
# classes. Confirm against process.get_palette.
palette = get_palette(4)

def run(img):
    """Segment clothing in *img* and return the rendered mask.

    Args:
        img: The image delivered by the ``gr.Image(type="pil")`` input
            component. It is ``None`` when the user clicks "Run!" without
            uploading or selecting an image.

    Returns:
        The result of ``generate_mask`` for the image, which is shown in the
        output component. Returns ``None`` when there is no input. That leaves
        the output empty rather than passing ``None`` into ``generate_mask``.
    """
    if img is None:
        return None
    return generate_mask(img, net=net, palette=palette, device=device)

# Custom CSS injected into gr.Blocks. The `#image_upload` rules size the input
# image component (elem_id="image_upload"). The `.footer` and
# `.acknowledgments` rules style the credits HTML at the bottom of the page.
# `.dark` rules presumably apply under Gradio's dark theme (verify).
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

# Collect example images from `input/` for the gr.Examples gallery. Only
# regular files with an image extension are kept, so stray entries such as
# `.DS_Store`, `.gitkeep` or sub-directories don't become broken examples.
# Paths are sorted so the gallery order is stable across platforms.
image_dir = 'input'
_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.webp', '.bmp', '.gif', '.tif', '.tiff'}
image_list = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    and os.path.isfile(os.path.join(image_dir, name))
)
# gr.Examples expects one row per example, with one value per input component.
examples = [[path] for path in image_list]

# Build the UI: a header, an input/output image pair, the example gallery, a
# "Run!" button wired to `run`, and a credits footer.
with gr.Blocks(css=css) as demo:
    gr.HTML(read_content("header.html"))

    with gr.Row():
        with gr.Column():
            image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")

        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img")

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[image],
            label="Examples - Input Images",
            examples_per_page=12
        )
        btn = gr.Button("Run!")

    btn.click(fn=run, inputs=[image], outputs=[image_out])

    # Credits footer. The heading is no longer wrapped in <p>: an <h4> inside
    # a <p> is invalid HTML, and browsers split it into stray empty paragraphs.
    # NOTE(review): the "WildOctopus" link has an empty href, so it only
    # reopens this page in a new tab. Fill in the intended URL.
    gr.HTML(
        """
        <div class="footer">
            <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
        </div>
        <div class="acknowledgments">
            <h4>ACKNOWLEDGEMENTS</h4>
            <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
            <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
        </div>
        """
    )

# Start the server only when this file is executed as a script (which is how
# Hugging Face Spaces starts Gradio apps), not when it is imported.
# share=True is intentionally not passed. Spaces already serves the app
# publicly and does not support share links. Locally, share=True would expose
# the demo through a public tunnel by default. Pass share=True explicitly when
# a temporary public link is actually wanted.
if __name__ == "__main__":
    demo.launch()