"""Gradio demo app for clothing segmentation.

Loads the segmentation checkpoint at import time, then serves a small UI that
takes an uploaded image and shows the result of ``process.generate_mask``.
"""

import os

import PIL  # NOTE(review): not referenced directly in this file.
import torch
import gradio as gr

from process import load_seg_model, get_palette, generate_mask

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize models. Any failure here is fatal: the app cannot serve requests
# without a loaded network, so it is re-raised as a RuntimeError.
try:
    checkpoint_path = 'model/cloth_segm.pth'
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    net = load_seg_model(checkpoint_path, device=device)
    # NOTE(review): 4 is presumably the number of segmentation classes
    # (including background) the checkpoint was trained with — confirm.
    palette = get_palette(4)
except Exception as e:
    # Chain the original exception so its traceback is preserved.
    raise RuntimeError(f"Failed to initialize models: {e}") from e


def run(img):
    """Segment *img* and return the output of ``generate_mask``.

    Args:
        img: The uploaded image (a PIL image, per the input component's
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        Whatever ``generate_mask`` returns; it is displayed in the output
        image component.

    Raises:
        gr.Error: If no image was provided or segmentation fails. Gradio
            surfaces the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e


# File extensions (lower-case) accepted as example images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _find_example_images(directory):
    """Return ``[[path], ...]`` rows for image files directly inside *directory*.

    Entries are sorted by file name. Returns an empty list when *directory*
    does not exist or is not a directory. Sub-directories are skipped even if
    their names end with an image extension.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    ]


# Handle examples
image_dir = 'input'
examples = _find_example_images(image_dir)

# Create interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")
    # Only build the examples panel when example images exist; with
    # cache_examples=True each example is run through `run` ahead of time.
    if examples:
        with gr.Row():
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=run,
                cache_examples=True,
                label="Example Images",
            )
    submit_btn = gr.Button("Segment", variant="primary")
    submit_btn.click(fn=run, inputs=input_image, outputs=output_image)

# Launch with appropriate settings
try:
    demo.launch(server_name="0.0.0.0", server_port=7860)
except Exception as e:
    print(f"Error launching app: {e}")
    # Fallback: retry with Gradio's default server settings and a public
    # share link.
    demo.launch(share=True)