import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask
device = 'cpu'
def read_content(file_path: str) -> str:
    """Return the entire contents of *file_path*, decoded as UTF-8."""
    with open(file_path, encoding='utf-8') as handle:
        text = handle.read()
    return text
def initialize_and_load_models():
    """Load the cloth-segmentation network from the bundled checkpoint.

    Returns the model produced by ``load_seg_model`` for the module-level
    ``device`` (CPU here).
    """
    weights_path = 'model/cloth_segm.pth'
    model = load_seg_model(weights_path, device=device)
    return model
# Load the segmentation model once at import time so every request reuses it.
net = initialize_and_load_models()
# Color palette for 4 segmentation classes — presumably background plus
# three garment categories; confirm against process.get_palette.
palette = get_palette(4)
def run(img):
    """Segment clothing in *img* (a PIL image from the Gradio widget).

    Delegates to ``generate_mask`` using the module-level model, palette,
    and device; returns whatever mask image that helper produces.
    """
    return generate_mask(img, net=net, palette=palette, device=device)
# Custom CSS injected into the Blocks layout (upload-area sizing, footer,
# dark-mode tweaks).
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''

# Build the example gallery from every file in the local `input/` directory,
# in sorted path order; gr.Examples expects one [path] row per example.
image_dir = 'input'
example_paths = sorted(os.path.join(image_dir, fname) for fname in os.listdir(image_dir))
examples = [[path] for path in example_paths]
with gr.Blocks(css=css) as demo:
    # Page header rendered from a static HTML fragment on disk.
    gr.HTML(read_content("header.html"))

    # Input on the left, segmentation result on the right.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")
        with gr.Column():
            output_image = gr.Image(label="Output", elem_id="output-img")

    # Clickable example images that populate the input widget.
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[input_image],
            label="Examples - Input Images",
            examples_per_page=12,
        )

    run_button = gr.Button("Run!")
    run_button.click(fn=run, inputs=[input_image], outputs=[output_image])

    # Static footer and acknowledgements.
    gr.HTML(
        """
        <div class="footer">
        <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
        </div>
        <div class="acknowledgments">
        <p><h4>ACKNOWLEDGEMENTS</h4></p>
        <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
        <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
        </div>
        """
    )

# share=True asks Gradio for a temporary public URL in addition to the
# local server, so the demo is reachable from hosted environments.
demo.launch(share=True)