# Source: Hugging Face Space file view (author: wildoctopus).
# Commit 56a9e12 "Update app.py" (verified); file size 1.99 kB.
import os
import torch
import gradio as gr
from PIL import Image
from process import load_seg_model, get_palette, generate_mask
# Pick the compute device once at import time: use the GPU when CUDA is
# available to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_model(model_dir='model', checkpoint_name='cloth_segm.pth', num_classes=4):
    """Load the segmentation network and its colour palette.

    The checkpoint must already exist on disk; this function never downloads it.

    Args:
        model_dir: Directory containing the checkpoint. A relative path is
            resolved against the current working directory.
        checkpoint_name: File name of the checkpoint inside ``model_dir``.
        num_classes: Value forwarded to ``get_palette`` (4 by default;
            presumably the number of segmentation classes -- confirm in
            ``process.py``).

    Returns:
        A ``(net, palette)`` tuple: the result of ``load_seg_model`` and the
        result of ``get_palette``.

    Raises:
        FileNotFoundError: If no regular file exists at the checkpoint path.
        RuntimeError: If loading the model or building the palette raises;
            the original exception is chained as ``__cause__``.
    """
    checkpoint_path = os.path.join(model_dir, checkpoint_name)
    # Fail fast with an actionable message. isfile (not exists) so that a
    # directory with the checkpoint's name is rejected here, not deep inside
    # the loader with a confusing error.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            f"Model not found at {checkpoint_path}. "
            "Please upload the model file to your Space's repository."
        )
    try:
        net = load_seg_model(checkpoint_path, device=device)
        palette = get_palette(num_classes)
    except Exception as e:
        # Chain the original exception so its traceback is not lost.
        raise RuntimeError(f"Model loading failed: {e}") from e
    return net, palette
# Build the network and its palette once, at import time, so every request
# reuses them. If load_model raises (missing or unloadable checkpoint), the
# exception propagates out of module import and the app does not start.
(net, palette) = load_model()
def process_image(img: Image.Image) -> Image.Image:
    """Run cloth segmentation on an uploaded image and return the mask.

    Args:
        img: Image from the Gradio input component; ``None`` when the user
            pressed the button without uploading anything.

    Returns:
        Whatever ``generate_mask`` produces for ``img`` using the module-level
        ``net``, ``palette`` and ``device`` (annotated as a PIL image).

    Raises:
        gr.Error: If no image was supplied, or if segmentation raises. In the
            latter case the original exception is chained as ``__cause__``.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Report a readable message through gr.Error. Chain the original
        # exception so its traceback is preserved for debugging.
        raise gr.Error(f"Processing failed: {e}") from e
# ---- Gradio user interface ----
title = "Cloth Segmentation Demo"
description = """
Upload an image to get cloth segmentation using U2NET.
"""

with gr.Blocks() as demo:
    # Header: the title as a level-2 Markdown heading, followed by the blurb.
    gr.Markdown("## " + title)
    gr.Markdown(description)

    # One row with two columns: upload + button first, result second.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result", type="pil")

    # Wire the button: the uploaded image goes through process_image and the
    # returned image is shown in the result component.
    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()