|
import os |
|
import torch |
|
import gradio as gr |
|
from PIL import Image |
|
from process import load_seg_model, get_palette, generate_mask |
|
|
|
|
|
# Pick the inference device once at import time: use the GPU when PyTorch
# reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
def load_model(model_dir: str = 'model'):
    """Load the segmentation network and its palette.

    The checkpoint is read from ``<model_dir>/cloth_segm.pth``. The default
    relative ``model`` directory is the path intended for Hugging Face Spaces.

    Args:
        model_dir: Directory that contains ``cloth_segm.pth``. Defaults to
            ``'model'``.

    Returns:
        A ``(net, palette)`` tuple: the value returned by ``load_seg_model``
        for the module-level ``device``, and the value of ``get_palette(4)``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If ``load_seg_model`` or ``get_palette`` raises; the
            original exception is attached as ``__cause__``.
    """
    checkpoint_path = os.path.join(model_dir, 'cloth_segm.pth')

    # Fail early with an actionable message rather than whatever error the
    # loader would raise for a missing file.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model not found at {checkpoint_path}. "
            "Please upload the model file to your Space's repository."
        )

    try:
        net = load_seg_model(checkpoint_path, device=device)
        palette = get_palette(4)
    except Exception as e:
        # Chain explicitly so the traceback shows the loader failure as the
        # direct cause of this error.
        raise RuntimeError(f"Model loading failed: {e}") from e

    return net, palette
|
|
|
|
|
# Load the model once at import time; process_image reads these module-level
# objects on every call instead of reloading the checkpoint.
net, palette = load_model()
|
|
|
def process_image(img: Image.Image) -> Image.Image:
    """Run segmentation on an uploaded image and return the result.

    Args:
        img: The image passed in by the Gradio input component, or ``None``
            when nothing has been uploaded.

    Returns:
        The value returned by ``generate_mask`` for ``img``, using the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If ``img`` is ``None``, or if ``generate_mask`` raises. In
            the latter case the original exception is attached as
            ``__cause__``.
    """
    if img is None:
        raise gr.Error("Please upload an image first")

    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Surface the failure in the UI while keeping the underlying
        # exception chained for the server-side traceback.
        raise gr.Error(f"Processing failed: {e}") from e
|
|
|
|
|
# Text rendered as Markdown at the top of the demo page.
title = "Cloth Segmentation Demo"

# Newlines are written explicitly so the surrounding whitespace is visible.
description = (
    "\n"
    "Upload an image to get cloth segmentation using U2NET.\n"
)
|
|
|
# Build the page: a heading and description, then one row with two columns
# (input image plus button in the first, result image in the second).
with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Segmentation Result")

    # Clicking the button sends the input image through process_image and
    # shows the returned value in the result component.
    submit_btn.click(fn=process_image, inputs=input_image, outputs=output_image)
|
|
|
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()