|
import gradio as gr |
|
import torch |
|
from process import load_seg_model, get_palette, generate_mask |
|
from PIL import Image |
|
import os |
|
|
|
|
|
# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"

# Load the segmentation network and the palette once, at import time.
# Any failure is re-raised as RuntimeError, which aborts the import.
try:
    net = load_seg_model(model_path, device=device)
    palette = get_palette(4)
except Exception as e:
    # Chain explicitly so the original exception is recorded as the cause.
    raise RuntimeError(f"Failed to load model: {e}") from e
|
|
|
def process_image(input_img):
    """Run cloth segmentation on ``input_img`` and return the result.

    Args:
        input_img: The image to segment. A ``PIL.Image.Image`` is used as-is;
            any other value is passed through ``Image.fromarray`` first.

    Returns:
        Whatever ``generate_mask`` returns for the image, using the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If ``input_img`` is ``None``, or if conversion or mask
            generation fails (the underlying exception is attached as the
            cause).
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")

    try:
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)

        return generate_mask(input_img, net=net, palette=palette, device=device)
    except Exception as e:
        # Surface the failure through Gradio's error type while keeping the
        # original exception chained for debugging.
        raise gr.Error(f"Error processing image: {e}") from e
|
|
|
|
|
# Build the UI: an input column (image + button), an output column, an
# optional examples gallery, and the click handler wiring.
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from the order of the calls -- confirm against the intended
# layout.
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
    # 🧥 Cloth Segmentation App
    Upload an image or capture from your camera to get segmentation results.
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
                interactive=True,
            )
            submit_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False,
            )

    # Examples are optional: only offered when the "input" directory exists
    # and actually contains image files.
    example_dir = "input"
    if os.path.isdir(example_dir):
        # os.listdir returns entries in arbitrary order; sort so the example
        # list is stable between runs. NOTE(review): this presumably matters
        # for cache_examples=True, where cached outputs are expected to line
        # up with the examples -- confirm against the Gradio docs.
        example_images = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )

        # Skip the widget entirely when there is nothing to show.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                cache_examples=True,
                label="Example Images",
            )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )
|
|
|
|
|
# Launch the Gradio app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()