File size: 2,411 Bytes
56a9e12 6648f32 56a9e12 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 6648f32 77304c1 84e526f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import logging
import os

import gradio as gr
import torch
from PIL import Image

from process import load_seg_model, get_palette, generate_mask
# Initialize the segmentation model once, at import time, so every call to
# process_image reuses the same network, palette and device.
# Use the GPU when torch reports one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"
try:
    net = load_seg_model(model_path, device=device)
    # NOTE(review): 4 is presumably the number of segmentation classes
    # (e.g. background + cloth categories) — confirm against the model.
    palette = get_palette(4)
except Exception as e:
    # Fail fast at startup; chain the original exception so its traceback
    # is reported as the direct cause instead of being hidden.
    raise RuntimeError(f"Failed to load model: {e}") from e
def process_image(input_img):
    """Run cloth segmentation on one image and return the mask.

    Args:
        input_img: The image to segment. The Gradio input component is
            configured with ``type="pil"``, so this is normally a
            ``PIL.Image.Image``; any other value (e.g. a NumPy array) is
            converted with ``Image.fromarray`` first.

    Returns:
        The result of ``generate_mask`` for the image, computed with the
        module-level ``net``, ``palette`` and ``device``. It is displayed
        directly by the output ``gr.Image`` component.

    Raises:
        gr.Error: If no image was provided, or if conversion or mask
            generation fails. Gradio shows the message to the user.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        # Convert to PIL Image if it's not already.
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)
        return generate_mask(input_img, net=net, palette=palette, device=device)
    except Exception as e:
        # The user only sees the short gr.Error message, so record the full
        # traceback server-side and keep the original exception as the cause.
        logging.getLogger(__name__).exception("Segmentation failed")
        raise gr.Error(f"Error processing image: {e}") from e
# Build the Gradio UI: an input image (upload or webcam) and a Process button
# on the left, the segmentation result on the right.
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
# 🧥 Cloth Segmentation App
Upload an image or capture from your camera to get segmentation results.
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",  # process_image receives a PIL image
                label="Input Image",
                interactive=True,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False,
            )

    # Optional examples: every .png/.jpg/.jpeg file in ./input, if present.
    example_dir = "input"
    # isdir (not exists): a plain *file* named "input" would otherwise make
    # os.listdir raise and abort app startup.
    if os.path.isdir(example_dir):
        # Sorted because os.listdir returns entries in arbitrary order;
        # this keeps the example gallery stable between runs.
        example_images = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        # Skip the Examples component entirely when no images were found.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                # Outputs are precomputed by running process_image on each
                # example, so the model must be loadable for caching to work.
                cache_examples=True,
                label="Example Images",
            )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )
# Launch the app with Gradio's default settings when run as a script.
if __name__ == "__main__":
    demo.launch()