"""Gradio web app for cloth segmentation.

Loads the segmentation model once at import time and exposes a Blocks UI:
upload or capture an image, click "Generate Mask", and the output of
``process.generate_mask`` is displayed.
"""

import os

import gradio as gr
import torch
from PIL import Image

from process import load_seg_model, get_palette, generate_mask

# Model initialization
# Use the GPU when torch reports one is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory scanned for optional example images shown below the UI.
EXAMPLES_DIR = "examples"
# Accepted example-image extensions, compared against the lower-cased file name.
EXAMPLE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def load_models():
    """Load the segmentation network and its colour palette.

    Returns:
        A ``(net, palette)`` tuple: the result of
        ``load_seg_model("model/cloth_segm.pth", device=device)`` and of
        ``get_palette(4)``.

    Raises:
        gr.Error: if either call fails. The original exception is chained
            as ``__cause__`` so its traceback is not lost.
    """
    try:
        net = load_seg_model("model/cloth_segm.pth", device=device)
        palette = get_palette(4)
    except Exception as e:
        raise gr.Error(f"Model failed to load: {str(e)}") from e
    return net, palette


net, palette = load_models()


def _list_example_images(directory=EXAMPLES_DIR):
    """Return sorted paths of image files in *directory*, or [] if it is not a directory."""
    # isdir (not exists): a plain file with this name must not reach listdir.
    if not os.path.isdir(directory):
        return []
    # listdir order is arbitrary; sort so the gallery is stable across runs.
    # Lower-case the name so ".JPG"/".PNG" files are not silently skipped.
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(EXAMPLE_EXTENSIONS)
    ]


def predict(image):
    """Run cloth segmentation on *image* and return ``generate_mask``'s result.

    Args:
        image: The input image. It is normally a ``PIL.Image.Image`` because
            the input component uses ``type="pil"``. Anything else is
            converted with ``Image.fromarray`` (presumably an array-like;
            confirm against other callers).

    Raises:
        gr.Error: if no image was provided, or if conversion or mask
            generation fails. The underlying exception is chained as
            ``__cause__``.
    """
    if image is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        return generate_mask(image, net=net, palette=palette, device=device)
    except Exception as e:
        raise gr.Error(f"Processing error: {str(e)}") from e


# Interface
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("## 👕 Cloth Segmentation Tool")
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(sources=["upload", "webcam"], type="pil", label="Input Image")
            btn = gr.Button("Generate Mask", variant="primary")
        with gr.Column():
            img_output = gr.Image(label="Segmentation Result")

    # Optional examples: shown only when the directory holds at least one
    # image, since an empty example list has nothing to display or cache.
    example_images = _list_example_images()
    if example_images:
        gr.Examples(
            examples=example_images,
            inputs=img_input,
            outputs=img_output,
            fn=predict,
            # Ask Gradio to cache predict's outputs for the examples.
            cache_examples=True,
        )

    btn.click(predict, inputs=img_input, outputs=img_output)

if __name__ == "__main__":
    demo.launch()