wildoctopus's picture
Update app.py
77304c1 verified
raw
history blame
2.41 kB
import gradio as gr
import torch
from process import load_seg_model, get_palette, generate_mask
from PIL import Image
import os
# Initialize the segmentation model once at import time so every request
# reuses the same network instead of reloading it per call.
# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "model/cloth_segm.pth"
try:
    net = load_seg_model(model_path, device=device)
    # Palette handed to generate_mask for colouring the output mask.
    # NOTE(review): 4 is presumably the number of segmentation classes —
    # confirm against get_palette in process.py.
    palette = get_palette(4)
except Exception as e:
    # Fail fast at startup: the app cannot serve requests without the model.
    # Chain the original exception so its traceback survives in the logs.
    raise RuntimeError(f"Failed to load model: {e}") from e
def process_image(input_img):
    """Run cloth segmentation on one image and return the resulting mask.

    Args:
        input_img: Image received from the Gradio input component. The UI
            requests ``type="pil"``; any non-PIL input (e.g. an array) is
            converted with ``Image.fromarray`` before segmentation.

    Returns:
        Whatever ``generate_mask`` produces for this image; it is displayed
        in the output ``gr.Image`` component.

    Raises:
        gr.Error: If no image was provided, or if conversion/segmentation
            fails. The original exception is chained as ``__cause__``.
    """
    if input_img is None:
        raise gr.Error("Please upload or capture an image first")
    try:
        # Normalise to a PIL image before handing it to the model helper.
        if not isinstance(input_img, Image.Image):
            input_img = Image.fromarray(input_img)
        return generate_mask(input_img, net=net, palette=palette, device=device)
    except Exception as e:
        # Show a readable message in the UI while keeping the underlying
        # exception chained so the server-side traceback is not lost.
        raise gr.Error(f"Error processing image: {e}") from e
# Build a simple two-column UI: input image and button on the left,
# segmentation result on the right.
with gr.Blocks(title="Cloth Segmentation") as demo:
    gr.Markdown("""
# 🧥 Cloth Segmentation App
Upload an image or capture from your camera to get segmentation results.
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                sources=["upload", "webcam"],
                type="pil",
                label="Input Image",
                interactive=True,
            )
            submit_btn = gr.Button("Process", variant="primary")
        with gr.Column():
            output_image = gr.Image(
                label="Segmentation Result",
                interactive=False,
            )

    # Optional examples: every .png/.jpg/.jpeg file in ./input.
    # isdir (not exists) so a stray file named "input" cannot crash listdir;
    # sorted because os.listdir returns entries in arbitrary order.
    example_dir = "input"
    if os.path.isdir(example_dir):
        example_images = [
            os.path.join(example_dir, f)
            for f in sorted(os.listdir(example_dir))
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        # Skip the component entirely when there are no example images,
        # rather than creating an empty (and cached) examples set.
        if example_images:
            gr.Examples(
                examples=example_images,
                inputs=[input_image],
                outputs=[output_image],
                fn=process_image,
                cache_examples=True,
                label="Example Images",
            )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=output_image,
    )

# Start the Gradio server with default settings, only when this file is
# executed directly (not when it is imported).
if __name__ == "__main__":
    demo.launch()