"""Gradio app for cloth segmentation.

Loads a pretrained segmentation model at startup, then serves a simple
Blocks UI: upload or webcam-capture an image, click "Process", and get
the generated segmentation mask back.
"""

import os

import gradio as gr
import torch
from PIL import Image

from process import load_seg_model, get_palette, generate_mask

# Initialize model once at import time; pick GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"

try:
    net = load_seg_model(model_path, device=device)
    palette = get_palette(4)  # palette size 4 — presumably background + 3 cloth classes; confirm against process.py
except Exception as e:
    # Fail fast: the app cannot serve anything without the model.
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {str(e)}") from e


def process_image(input_img):
    """Process input image and return segmentation mask.

    Args:
        input_img: PIL Image (or ndarray from Gradio) to segment; may be
            None if the user clicked Process without providing an image.

    Returns:
        The segmentation mask produced by `generate_mask`.

    Raises:
        gr.Error: if no image was provided, or if mask generation fails.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        # Gradio delivers PIL when type="pil", but be tolerant of ndarrays
        # (e.g. when called directly or if the component type changes).
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)
        # Generate mask
        output_mask = generate_mask(input_img, net=net, palette=palette, device=device)
        return output_mask
    except Exception as e:
        raise gr.Error(f"Error processing image: {str(e)}")


# Create simple interface
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown(
        """
    # 🧥 Cloth Segmentation App
    Upload an image or capture from your camera to get segmentation results.
    """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
                interactive=True,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False,
            )

    # Examples section (optional)
    example_dir = "input"
    if os.path.exists(example_dir):
        # sorted() gives a deterministic, OS-independent example order.
        example_images = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        # Guard: gr.Examples raises if handed an empty examples list,
        # which happens when the directory exists but holds no images.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                cache_examples=True,
                label="Example Images",
            )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )

# Launch with appropriate settings
if __name__ == "__main__":
    demo.launch()