# app.py — last change "Update app.py" by wildoctopus (commit c396ac7, verified; 1.91 kB)
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize models once at import time so every request reuses the same
# network and palette. Any failure is re-raised as RuntimeError with the
# original exception chained (``from e``) so the startup traceback still
# shows the underlying cause.
checkpoint_path = 'model/cloth_segm.pth'
try:
    # isfile (not exists) so a directory at this path is reported clearly here
    # instead of failing later inside the loader.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    net = load_seg_model(checkpoint_path, device=device)
    # NOTE(review): 4 presumably matches the number of segmentation classes
    # produced by the model — confirm against process.get_palette.
    palette = get_palette(4)
except Exception as e:
    raise RuntimeError(f"Failed to initialize models: {e}") from e
def run(img):
    """Segment clothing in *img* and return the result of ``generate_mask``.

    Args:
        img: The uploaded image, delivered as a PIL image because the input
            component is declared with ``type="pil"``, or ``None`` when the
            user has not uploaded anything.

    Returns:
        Whatever ``generate_mask`` returns for the image; it is shown in the
        output ``gr.Image`` component.

    Raises:
        gr.Error: If no image was provided or segmentation fails. Gradio
            displays the message to the user.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except gr.Error:
        # Already user-facing; don't wrap it in a second "Error processing image" prefix.
        raise
    except Exception as e:
        # Chain the cause so the full traceback still reaches the server logs.
        raise gr.Error(f"Error processing image: {e}") from e
# Handle examples: image files found directly inside `image_dir` become
# one-element rows for gr.Examples (one row per image).
image_dir = 'input'


def find_example_images(directory, extensions=('.png', '.jpg', '.jpeg')):
    """Return example rows for every image file directly inside *directory*.

    Args:
        directory: Folder to scan (not recursive).
        extensions: File suffixes to accept, matched case-insensitively.

    Returns:
        A list of ``[path]`` rows in sorted filename order, or ``[]`` when
        *directory* does not exist or is not a directory.
    """
    # isdir rather than exists: os.listdir raises NotADirectoryError on a file.
    if not os.path.isdir(directory):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    rows = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        # Skip subdirectories that merely happen to have an image-like name.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            rows.append([path])
    return rows


examples = find_example_images(image_dir)
# Create interface: input image on the left, segmentation result on the
# right, an optional example gallery below, and a button that runs `run`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")
    # Skip the gallery entirely when no example files were found: there is
    # nothing to display and nothing for cache_examples to precompute.
    if examples:
        with gr.Row():
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=run,
                # Gradio runs `run` on each example ahead of time and serves
                # the cached output when an example is clicked.
                cache_examples=True,
                label="Example Images",
            )
    submit_btn = gr.Button("Segment", variant="primary")
    submit_btn.click(fn=run, inputs=input_image, outputs=output_image)
# Launch: serve on every interface at port 7860. If that launch raises, report
# the error and retry with Gradio's default host/port plus a public share link.
try:
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
except Exception as launch_err:
    print(f"Error launching app: {launch_err!s}")
    demo.launch(share=True)