wildoctopus commited on
Commit
56a9e12
·
verified ·
1 Parent(s): e657276

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -63
app.py CHANGED
@@ -1,75 +1,60 @@
1
- import gradio as gr
2
  import torch
3
- from process import load_seg_model, get_palette, generate_mask
4
  from PIL import Image
5
- import os
6
 
7
- # Initialize model
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_path = "model/cloth_segm.pth"
10
 
11
- try:
12
- net = load_seg_model(model_path, device=device)
13
- palette = get_palette(4)
14
- except Exception as e:
15
- raise RuntimeError(f"Failed to load model: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- def process_image(input_img):
18
  """Process input image and return segmentation mask"""
19
- if input_img is None:
20
- raise gr.Error("Please upload or capture an image first")
21
 
22
  try:
23
- # Convert to PIL Image if it's not already
24
- if not isinstance(input_img, Image.Image):
25
- input_img = Image.fromarray(input_img)
26
-
27
- # Generate mask
28
- output_mask = generate_mask(input_img, net=net, palette=palette, device=device)
29
- return output_mask
30
  except Exception as e:
31
- raise gr.Error(f"Error processing image: {str(e)}")
32
 
33
- # Create simple interface
34
- with gr.Blocks(title="Cloth Segmentation") as demo:
35
- gr.Markdown("""
36
- # 🧥 Cloth Segmentation App
37
- Upload an image or capture from your camera to get segmentation results.
38
- """)
 
 
 
39
 
40
  with gr.Row():
41
  with gr.Column():
42
- input_image = gr.Image(
43
- sources=["upload", "webcam"],
44
- type="pil",
45
- label="Input Image",
46
- interactive=True
47
- )
48
  submit_btn = gr.Button("Process", variant="primary")
49
-
50
  with gr.Column():
51
- output_image = gr.Image(
52
- label="Segmentation Result",
53
- interactive=False
54
- )
55
-
56
- # Examples section (optional)
57
- example_dir = "input"
58
- if os.path.exists(example_dir):
59
- example_images = [
60
- os.path.join(example_dir, f)
61
- for f in os.listdir(example_dir)
62
- if f.lower().endswith(('.png', '.jpg', '.jpeg'))
63
- ]
64
-
65
- gr.Examples(
66
- examples=example_images,
67
- inputs=[input_image],
68
- outputs=[output_image],
69
- fn=process_image,
70
- cache_examples=True,
71
- label="Example Images"
72
- )
73
 
74
  submit_btn.click(
75
  fn=process_image,
@@ -77,10 +62,5 @@ with gr.Blocks(title="Cloth Segmentation") as demo:
77
  outputs=output_image
78
  )
79
 
80
- # Launch with appropriate settings
81
  if __name__ == "__main__":
82
- demo.launch(
83
- server_name="0.0.0.0",
84
- server_port=7860,
85
- show_error=True
86
- )
 
1
+ import os
2
  import torch
3
+ import gradio as gr
4
  from PIL import Image
5
+ from process import load_seg_model, get_palette, generate_mask
6
 
7
+ # Device selection
8
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
9
 
10
+ def load_model():
11
+ """Load model with Hugging Face Spaces compatible paths"""
12
+ model_dir = 'model'
13
+ checkpoint_path = os.path.join(model_dir, 'cloth_segm.pth')
14
+
15
+ # Verify model exists (must be pre-uploaded to HF Spaces)
16
+ if not os.path.exists(checkpoint_path):
17
+ raise FileNotFoundError(
18
+ f"Model not found at {checkpoint_path}. "
19
+ "Please upload the model file to your Space's repository."
20
+ )
21
+
22
+ try:
23
+ net = load_seg_model(checkpoint_path, device=device)
24
+ palette = get_palette(4)
25
+ return net, palette
26
+ except Exception as e:
27
+ raise RuntimeError(f"Model loading failed: {str(e)}")
28
+
29
+ # Initialize model (will fail fast if there's an issue)
30
+ net, palette = load_model()
31
 
32
+ def process_image(img: Image.Image) -> Image.Image:
33
  """Process input image and return segmentation mask"""
34
+ if img is None:
35
+ raise gr.Error("Please upload an image first")
36
 
37
  try:
38
+ return generate_mask(img, net=net, palette=palette, device=device)
 
 
 
 
 
 
39
  except Exception as e:
40
+ raise gr.Error(f"Processing failed: {str(e)}")
41
 
42
+ # Gradio interface
43
+ title = "Cloth Segmentation Demo"
44
+ description = """
45
+ Upload an image to get cloth segmentation using U2NET.
46
+ """
47
+
48
+ with gr.Blocks() as demo:
49
+ gr.Markdown(f"## {title}")
50
+ gr.Markdown(description)
51
 
52
  with gr.Row():
53
  with gr.Column():
54
+ input_image = gr.Image(type="pil", label="Input Image")
 
 
 
 
 
55
  submit_btn = gr.Button("Process", variant="primary")
 
56
  with gr.Column():
57
+ output_image = gr.Image(type="pil", label="Segmentation Result")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  submit_btn.click(
60
  fn=process_image,
 
62
  outputs=output_image
63
  )
64
 
 
65
  if __name__ == "__main__":
66
+ demo.launch()