File size: 1,912 Bytes
896437a 1c6ea49 896437a 6984480 896437a c396ac7 896437a 6984480 c396ac7 6984480 c396ac7 6984480 896437a c396ac7 1369068 c396ac7 6984480 c396ac7 1c6ea49 c396ac7 9112e74 c396ac7 9112e74 c396ac7 9112e74 541e07e c396ac7 1369068 c396ac7 9112e74 c396ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import os
import sys

import gradio as gr
import PIL
import torch

from process import load_seg_model, get_palette, generate_mask
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize models once at import time; `run` reuses these module-level
# objects for every request. Any failure aborts startup, because the app
# cannot serve requests without a model.
try:
    checkpoint_path = 'model/cloth_segm.pth'
    # Fail early with a clear message instead of whatever the loader raises.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    net = load_seg_model(checkpoint_path, device=device)
    # NOTE(review): 4 presumably matches the number of segmentation classes
    # the checkpoint predicts; confirm against process.get_palette.
    palette = get_palette(4)
except Exception as e:
    # Chain explicitly so the original error is reported as the direct cause.
    raise RuntimeError(f"Failed to initialize models: {e}") from e
def run(img):
    """Segment clothing in an uploaded image.

    Args:
        img: The image from the input component, or ``None`` when nothing
            has been uploaded. Presumably a PIL image, since the input
            ``gr.Image`` is declared with ``type="pil"``.

    Returns:
        The result of ``generate_mask`` for ``img``, computed with the
        module-level ``net``, ``palette`` and ``device``.

    Raises:
        gr.Error: If no image was provided, or if segmentation fails. In the
            latter case the original exception is chained as the cause.
    """
    if img is None:
        raise gr.Error("Please upload an image first")
    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except Exception as e:
        # gr.Error shows the message to the user; `from e` keeps the real
        # traceback attached for server-side debugging.
        raise gr.Error(f"Error processing image: {e}") from e
# Handle examples: collect image files from the local 'input' folder, sorted
# by filename so the examples appear in a stable order. A missing 'input'
# path, or one that is not a directory, simply yields no examples instead of
# crashing startup.
image_dir = 'input'
examples = []
if os.path.isdir(image_dir):
    examples = [
        [os.path.join(image_dir, f)]
        for f in sorted(os.listdir(image_dir))
        # Case-insensitive extension match; skip sub-directories that
        # merely have an image-like name.
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        and os.path.isfile(os.path.join(image_dir, f))
    ]
# Create interface: input image on the left, segmentation result on the
# right, an optional row of example images, and a button that runs `run`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")
    # Only build the examples panel when example images were found. An empty
    # gr.Examples would render an empty widget and still set up caching.
    if examples:
        with gr.Row():
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=run,
                # Gradio runs `run` on every example ahead of time and serves
                # the cached outputs when an example is clicked.
                cache_examples=True,
                label="Example Images",
            )
    submit_btn = gr.Button("Segment", variant="primary")
    submit_btn.click(fn=run, inputs=input_image, outputs=output_image)
# Launch with appropriate settings: listen on all interfaces on port 7860.
try:
    demo.launch(server_name="0.0.0.0", server_port=7860)
except Exception as e:
    # Report the failure on stderr so it is not mixed into normal output.
    print(f"Error launching app: {e}", file=sys.stderr)
    # Fallback: Gradio's default host/port, plus a public share link.
    demo.launch(share=True)