import os

import torch
import gradio as gr
from PIL import Image

from process import load_seg_model, get_palette, generate_mask

# Select GPU when available; inference falls back to CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model():
    """Load the cloth-segmentation model with HF Spaces compatible paths.

    Returns:
        tuple: ``(net, palette)`` — the loaded segmentation network and the
        4-class color palette used to render masks.

    Raises:
        FileNotFoundError: if the checkpoint is missing from the repository.
        RuntimeError: if the checkpoint exists but fails to load.
    """
    model_dir = 'model'
    checkpoint_path = os.path.join(model_dir, 'cloth_segm.pth')

    # Verify model exists (must be pre-uploaded to HF Spaces)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Model not found at {checkpoint_path}. "
            "Please upload the model file to your Space's repository."
        )

    try:
        net = load_seg_model(checkpoint_path, device=device)
        palette = get_palette(4)
        return net, palette
    except Exception as e:
        # Chain the original exception so the load-failure traceback
        # is preserved for debugging (was previously dropped).
        raise RuntimeError(f"Model loading failed: {e}") from e


# Initialize model at import time (will fail fast if there's an issue)
net, palette = load_model()


def process_image(img: Image.Image) -> Image.Image:
    """Run segmentation on *img* and return the colorized mask image.

    Raises:
        gr.Error: if no image was supplied or inference fails; the original
        exception is chained for server-side debugging.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the cause.
        raise gr.Error(f"Processing failed: {e}") from e


# Gradio interface
title = "Cloth Segmentation Demo"
description = """
Upload an image to get cloth segmentation using U2NET.
"""

with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Segmentation Result")

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image
    )

if __name__ == "__main__":
    demo.launch()