|
import PIL |
|
import torch |
|
import gradio as gr |
|
import os |
|
from process import load_seg_model, get_palette, generate_mask |
|
|
|
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Load the segmentation network and colour palette once at import time so that
# every request handled by `run` reuses the same objects.
try:
    checkpoint_path = 'model/cloth_segm.pth'
    # isfile (not exists) so a stray directory at this path is reported clearly.
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
    net = load_seg_model(checkpoint_path, device=device)
    # NOTE(review): 4 is presumably the number of segmentation classes the
    # checkpoint was trained with (incl. background) — confirm against `process`.
    palette = get_palette(4)
except Exception as e:
    # Chain the original error so its traceback is kept in the startup failure.
    raise RuntimeError(f"Failed to initialize models: {e}") from e
|
|
|
def run(img):
    """Segment clothing in an uploaded image.

    Args:
        img: The image from the input component (a PIL image, because the
            component is declared with ``type="pil"``), or ``None`` when the
            user has not uploaded anything.

    Returns:
        Whatever ``generate_mask`` produces for the module-level ``net``,
        ``palette`` and ``device``; it is shown in the output image component.

    Raises:
        gr.Error: When no image was supplied or segmentation fails. Gradio
            shows the message to the user in the UI.
    """
    if img is None:
        raise gr.Error("Please upload an image first")

    try:
        return generate_mask(img, net=net, palette=palette, device=device)
    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        # Keep the original exception as the cause so server logs show the real traceback.
        raise gr.Error(f"Error processing image: {e}") from e
|
|
|
|
|
# Directory (relative to the working directory) holding example images for the UI.
image_dir = 'input'

# File-name suffixes treated as example images (matched case-insensitively).
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def list_example_images(directory):
    """Return ``gr.Examples`` rows for the image files directly inside *directory*.

    Each row is a one-element list ``[path]``, matching an Examples component
    that has a single input. Rows are ordered by sorted file name.
    Sub-directories are skipped, even if their name ends with an image
    extension.

    Returns ``[]`` when *directory* does not exist or is not a directory.
    This includes the case where it is a regular file, on which
    ``os.listdir`` would otherwise raise.
    """
    if not os.path.isdir(directory):
        return []

    rows = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if name.lower().endswith(IMAGE_EXTENSIONS) and os.path.isfile(path):
            rows.append([path])
    return rows


examples = list_example_images(image_dir)
|
|
|
|
|
# UI layout: input and output images side by side, an optional row of
# clickable examples, and a button that runs the segmentation.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Segmentation Result")

    # Only build the examples panel when images were found. With an empty list
    # it would render as an empty, useless section.
    if examples:
        with gr.Row():
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=run,
                # Runs `run` on every example ahead of time so clicks return instantly.
                cache_examples=True,
                label="Example Images",
            )

    submit_btn = gr.Button("Segment", variant="primary")
    submit_btn.click(fn=run, inputs=input_image, outputs=output_image)
|
|
|
|
|
# Serve on all interfaces at port 7860. Only if that fails (e.g. the port is
# taken) fall back to Gradio's default host/port selection with a public
# share link. The fallback must live inside `except`: `launch` blocks until
# the server is shut down, and an unconditional second `launch` would
# relaunch the app publicly every time the first server exits normally.
try:
    demo.launch(server_name="0.0.0.0", server_port=7860)
except Exception as e:
    print(f"Error launching app: {e}")
    demo.launch(share=True)