wildoctopus commited on
Commit
b882389
·
verified ·
1 Parent(s): 56a9e12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -62
app.py CHANGED
@@ -1,66 +1,6 @@
1
- import os
2
- import torch
3
  import gradio as gr
4
- from PIL import Image
5
- from process import load_seg_model, get_palette, generate_mask
6
-
7
- # Device selection
8
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
-
10
- def load_model():
11
- """Load model with Hugging Face Spaces compatible paths"""
12
- model_dir = 'model'
13
- checkpoint_path = os.path.join(model_dir, 'cloth_segm.pth')
14
-
15
- # Verify model exists (must be pre-uploaded to HF Spaces)
16
- if not os.path.exists(checkpoint_path):
17
- raise FileNotFoundError(
18
- f"Model not found at {checkpoint_path}. "
19
- "Please upload the model file to your Space's repository."
20
- )
21
-
22
- try:
23
- net = load_seg_model(checkpoint_path, device=device)
24
- palette = get_palette(4)
25
- return net, palette
26
- except Exception as e:
27
- raise RuntimeError(f"Model loading failed: {str(e)}")
28
-
29
- # Initialize model (will fail fast if there's an issue)
30
- net, palette = load_model()
31
-
32
- def process_image(img: Image.Image) -> Image.Image:
33
- """Process input image and return segmentation mask"""
34
- if img is None:
35
- raise gr.Error("Please upload an image first")
36
-
37
- try:
38
- return generate_mask(img, net=net, palette=palette, device=device)
39
- except Exception as e:
40
- raise gr.Error(f"Processing failed: {str(e)}")
41
-
42
- # Gradio interface
43
- title = "Cloth Segmentation Demo"
44
- description = """
45
- Upload an image to get cloth segmentation using U2NET.
46
- """
47
 
48
  with gr.Blocks() as demo:
49
- gr.Markdown(f"## {title}")
50
- gr.Markdown(description)
51
-
52
- with gr.Row():
53
- with gr.Column():
54
- input_image = gr.Image(type="pil", label="Input Image")
55
- submit_btn = gr.Button("Process", variant="primary")
56
- with gr.Column():
57
- output_image = gr.Image(type="pil", label="Segmentation Result")
58
-
59
- submit_btn.click(
60
- fn=process_image,
61
- inputs=input_image,
62
- outputs=output_image
63
- )
64
 
65
- if __name__ == "__main__":
66
- demo.launch()
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  with gr.Blocks() as demo:
4
+ gr.Markdown("# Welcome to Gradio! 🎉")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ demo.launch()