wildoctopus commited on
Commit
77304c1
·
verified ·
1 Parent(s): 6648f32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -33
app.py CHANGED
@@ -4,55 +4,79 @@ from process import load_seg_model, get_palette, generate_mask
4
  from PIL import Image
5
  import os
6
 
7
- # Model initialization
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
- def load_models():
11
- try:
12
- net = load_seg_model("model/cloth_segm.pth", device=device)
13
- palette = get_palette(4)
14
- return net, palette
15
- except Exception as e:
16
- raise gr.Error(f"Model failed to load: {str(e)}")
17
 
18
- net, palette = load_models()
19
-
20
- def predict(image):
21
- if image is None:
22
  raise gr.Error("Please upload or capture an image first")
 
23
  try:
24
- if not isinstance(image, Image.Image):
25
- image = Image.fromarray(image)
26
- return generate_mask(image, net=net, palette=palette, device=device)
 
 
 
 
27
  except Exception as e:
28
- raise gr.Error(f"Processing error: {str(e)}")
29
 
30
- # Interface
31
  with gr.Blocks(title="Cloth Segmentation") as demo:
32
- gr.Markdown("## 👕 Cloth Segmentation Tool")
 
 
 
33
 
34
  with gr.Row():
35
  with gr.Column():
36
- img_input = gr.Image(sources=["upload", "webcam"],
37
- type="pil",
38
- label="Input Image")
39
- btn = gr.Button("Generate Mask", variant="primary")
 
 
 
40
 
41
  with gr.Column():
42
- img_output = gr.Image(label="Segmentation Result")
 
 
 
43
 
44
- # Optional examples
45
- if os.path.exists("examples"):
 
 
 
 
 
 
 
46
  gr.Examples(
47
- examples=[os.path.join("examples", f) for f in os.listdir("examples")
48
- if f.endswith(('.png','.jpg','.jpeg'))],
49
- inputs=img_input,
50
- outputs=img_output,
51
- fn=predict,
52
- cache_examples=True
53
  )
 
 
 
 
 
 
54
 
55
- btn.click(predict, inputs=img_input, outputs=img_output)
56
-
57
  if __name__ == "__main__":
58
  demo.launch()
 
4
  from PIL import Image
5
  import os
6
 
7
+ # Initialize model
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model_path = "model/cloth_segm.pth"
10
 
11
+ try:
12
+ net = load_seg_model(model_path, device=device)
13
+ palette = get_palette(4)
14
+ except Exception as e:
15
+ raise RuntimeError(f"Failed to load model: {str(e)}")
 
 
16
 
17
+ def process_image(input_img):
18
+ """Process input image and return segmentation mask"""
19
+ if input_img is None:
 
20
  raise gr.Error("Please upload or capture an image first")
21
+
22
  try:
23
+ # Convert to PIL Image if it's not already
24
+ if not isinstance(input_img, Image.Image):
25
+ input_img = Image.fromarray(input_img)
26
+
27
+ # Generate mask
28
+ output_mask = generate_mask(input_img, net=net, palette=palette, device=device)
29
+ return output_mask
30
  except Exception as e:
31
+ raise gr.Error(f"Error processing image: {str(e)}")
32
 
33
+ # Create simple interface
34
  with gr.Blocks(title="Cloth Segmentation") as demo:
35
+ gr.Markdown("""
36
+ # 🧥 Cloth Segmentation App
37
+ Upload an image or capture from your camera to get segmentation results.
38
+ """)
39
 
40
  with gr.Row():
41
  with gr.Column():
42
+ input_image = gr.Image(
43
+ sources=["upload", "webcam"],
44
+ type="pil",
45
+ label="Input Image",
46
+ interactive=True
47
+ )
48
+ submit_btn = gr.Button("Process", variant="primary")
49
 
50
  with gr.Column():
51
+ output_image = gr.Image(
52
+ label="Segmentation Result",
53
+ interactive=False
54
+ )
55
 
56
+ # Examples section (optional)
57
+ example_dir = "input"
58
+ if os.path.exists(example_dir):
59
+ example_images = [
60
+ os.path.join(example_dir, f)
61
+ for f in os.listdir(example_dir)
62
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))
63
+ ]
64
+
65
  gr.Examples(
66
+ examples=example_images,
67
+ inputs=[input_image],
68
+ outputs=[output_image],
69
+ fn=process_image,
70
+ cache_examples=True,
71
+ label="Example Images"
72
  )
73
+
74
+ submit_btn.click(
75
+ fn=process_image,
76
+ inputs=input_image,
77
+ outputs=output_image
78
+ )
79
 
80
+ # Launch with appropriate settings
 
81
  if __name__ == "__main__":
82
  demo.launch()