File size: 3,792 Bytes
896437a 1c6ea49 896437a 6984480 896437a 1c6ea49 6984480 896437a 6984480 896437a 6984480 896437a 6984480 896437a 9112e74 1c6ea49 9112e74 1369068 6984480 541e07e 1c6ea49 1369068 1c6ea49 9112e74 1369068 9112e74 1369068 9112e74 541e07e 9112e74 1369068 6984480 1369068 9112e74 1369068 9112e74 6984480 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask
# Run inference on GPU when CUDA is available, otherwise fall back to CPU;
# this device string is passed to both model loading and mask generation.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def read_content(file_path: str) -> str:
    """Return the UTF-8 text of *file_path*, or "" when it cannot be read.

    A missing file logs a warning; any other read failure logs the error.
    Never raises — callers always get a string back.
    """
    result = ""
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            result = handle.read()
    except FileNotFoundError:
        print(f"Warning: File {file_path} not found")
    except Exception as e:
        print(f"Error reading file {file_path}: {str(e)}")
    return result
def initialize_and_load_models():
    """Load the cloth-segmentation network, returning None on any failure.

    Errors (missing checkpoint, load failure) are printed rather than
    raised so the caller can decide how to react to a missing model.
    """
    checkpoint_path = 'model/cloth_segm.pth'
    try:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
        net = load_seg_model(checkpoint_path, device=device)
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None
    return net
# Load the segmentation network once at import time; fail fast at startup
# if it is unavailable so the app never serves with a broken model.
net = initialize_and_load_models()
if net is None:
    raise RuntimeError("Failed to load model - check logs for details")
# Color palette for 4 mask classes (presumably background + 3 cloth
# categories — confirm against process.get_palette).
palette = get_palette(4)
def run(img):
    """Generate a cloth-segmentation mask for an uploaded image.

    Args:
        img: PIL image from the Gradio input widget, or None when the
            user clicked Run without uploading anything.

    Returns:
        The mask produced by generate_mask for the loaded net/palette.

    Raises:
        gr.Error: if no image was supplied, mask generation returned
            None, or mask generation raised an exception.
    """
    if img is None:
        raise gr.Error("No image uploaded")
    try:
        cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Chain the original exception so the real cause appears in logs.
        raise gr.Error(f"Error processing image: {str(e)}") from e
    # Checked outside the try block: previously this gr.Error was raised
    # inside it and got caught and re-wrapped by the generic handler,
    # surfacing as "Error processing image: Failed to generate mask".
    if cloth_seg is None:
        raise gr.Error("Failed to generate mask")
    return cloth_seg
# Custom CSS injected into gr.Blocks below: widens the container, enforces
# a minimum height on the upload widget, and styles the footer for both
# light and dark themes.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
'''
# Gather example images (png/jpg/jpeg) from the input/ directory, if any,
# and wrap each path in a single-element list as gr.Examples expects.
image_dir = 'input'
image_list = []
if os.path.exists(image_dir):
    image_list = sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
examples = [[img] for img in image_list]
# --- Gradio UI -------------------------------------------------------------
# Two-column layout: input image on the left, generated mask on the right,
# with clickable example images below and a Run button wired to run().
with gr.Blocks(css=css) as demo:
    # Page header markup loaded from disk ("" if header.html is missing).
    gr.HTML(read_content("header.html"))
    with gr.Row():
        with gr.Column():
            # type="pil" delivers the upload to run() as a PIL image.
            image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")
        with gr.Column():
            image_out = gr.Image(label="Output", elem_id="output-img")
    with gr.Row():
        # Clicking an example populates the input image widget.
        gr.Examples(
            examples=examples,
            inputs=[image],
            label="Examples - Input Images",
            examples_per_page=12
        )
    btn = gr.Button("Run!", variant="primary")
    btn.click(fn=run, inputs=[image], outputs=[image_out])
    # Static footer and acknowledgements rendered below the demo.
    gr.HTML(
        """
        <div class="footer">
        <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
        </div>
        <div class="acknowledgments">
        <p><h4>ACKNOWLEDGEMENTS</h4></p>
        <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
        <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
        </div>
        """
    )
# Entry point: start the Gradio server. On Hugging Face Spaces, launch()
# is called without share=True (Spaces supplies the public URL).
# Fix: removed a stray trailing "|" extraction artifact that made this
# line a syntax error.
demo.launch()