wildoctopus committed
Commit 0038d2b · verified · 1 Parent(s): 77304c1

Update app.py

Files changed (1)
  1. app.py +90 -66
app.py CHANGED
@@ -1,82 +1,106 @@
- import gradio as gr
  import torch
- from process import load_seg_model, get_palette, generate_mask
- from PIL import Image
  import os

- # Initialize model
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_path = "model/cloth_segm.pth"

- try:
-     net = load_seg_model(model_path, device=device)
-     palette = get_palette(4)
- except Exception as e:
-     raise RuntimeError(f"Failed to load model: {str(e)}")

- def process_image(input_img):
-     """Process input image and return segmentation mask"""
-     if input_img is None:
-         raise gr.Error("Please upload or capture an image first")
-
      try:
-         # Convert to PIL Image if it's not already
-         if not isinstance(input_img, Image.Image):
-             input_img = Image.fromarray(input_img)
-
-         # Generate mask
-         output_mask = generate_mask(input_img, net=net, palette=palette, device=device)
-         return output_mask
      except Exception as e:
          raise gr.Error(f"Error processing image: {str(e)}")

- # Create simple interface
- with gr.Blocks(title="Cloth Segmentation") as demo:
-     gr.Markdown("""
-     # 🧥 Cloth Segmentation App
-     Upload an image or capture from your camera to get segmentation results.
-     """)
-
      with gr.Row():
          with gr.Column():
-             input_image = gr.Image(
-                 sources=["upload", "webcam"],
-                 type="pil",
-                 label="Input Image",
-                 interactive=True
-             )
-             submit_btn = gr.Button("Process", variant="primary")
-
          with gr.Column():
-             output_image = gr.Image(
-                 label="Segmentation Result",
-                 interactive=False
-             )
-
-     # Examples section (optional)
-     example_dir = "input"
-     if os.path.exists(example_dir):
-         example_images = [
-             os.path.join(example_dir, f)
-             for f in os.listdir(example_dir)
-             if f.lower().endswith(('.png', '.jpg', '.jpeg'))
-         ]
-
          gr.Examples(
-             examples=example_images,
-             inputs=[input_image],
-             outputs=[output_image],
-             fn=process_image,
-             cache_examples=True,
-             label="Example Images"
          )
-
-     submit_btn.click(
-         fn=process_image,
-         inputs=input_image,
-         outputs=output_image
      )

- # Launch with appropriate settings
- if __name__ == "__main__":
-     demo.launch()
 
+ import PIL
  import torch
+ import gradio as gr
  import os
+ from process import load_seg_model, get_palette, generate_mask

+ device = 'cuda' if torch.cuda.is_available() else 'cpu'

+ def read_content(file_path: str) -> str:
+     """Read file content with error handling"""
+     try:
+         with open(file_path, 'r', encoding='utf-8') as f:
+             return f.read()
+     except FileNotFoundError:
+         print(f"Warning: File {file_path} not found")
+         return ""
+     except Exception as e:
+         print(f"Error reading file {file_path}: {str(e)}")
+         return ""

+ def initialize_and_load_models():
+     """Initialize and load models with error handling"""
      try:
+         checkpoint_path = 'model/cloth_segm.pth'
+         if not os.path.exists(checkpoint_path):
+             raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
+         return load_seg_model(checkpoint_path, device=device)
+     except Exception as e:
+         print(f"Error loading model: {str(e)}")
+         return None
+
+ net = initialize_and_load_models()
+ if net is None:
+     raise RuntimeError("Failed to load model - check logs for details")
+
+ palette = get_palette(4)
+
+ def run(img):
+     """Process image with error handling"""
+     if img is None:
+         raise gr.Error("No image uploaded")
+     try:
+         cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
+         if cloth_seg is None:
+             raise gr.Error("Failed to generate mask")
+         return cloth_seg
      except Exception as e:
          raise gr.Error(f"Error processing image: {str(e)}")

+ # CSS styling
+ css = '''
+ .container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
+ #image_upload{min-height:400px}
+ #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
+ .footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
+ .footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
+ .dark .footer {border-color: #303030}
+ .dark .footer>p {background: #0b0f19}
+ .acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
+ #image_upload .touch-none{display: flex}
+ '''
+
+ # Collect example images
+ image_dir = 'input'
+ image_list = []
+ if os.path.exists(image_dir):
+     image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]
+     image_list.sort()
+ examples = [[img] for img in image_list]
+
+ with gr.Blocks(css=css) as demo:
+     gr.HTML(read_content("header.html"))
+
      with gr.Row():
          with gr.Column():
+             image = gr.Image(elem_id="image_upload", type="pil", label="Input Image")
+
          with gr.Column():
+             image_out = gr.Image(label="Output", elem_id="output-img")
+
+     with gr.Row():
          gr.Examples(
+             examples=examples,
+             inputs=[image],
+             label="Examples - Input Images",
+             examples_per_page=12
          )
+         btn = gr.Button("Run!", variant="primary")
+
+     btn.click(fn=run, inputs=[image], outputs=[image_out])
+
+     gr.HTML(
+         """
+         <div class="footer">
+             <p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face</p>
+         </div>
+         <div class="acknowledgments">
+             <p><h4>ACKNOWLEDGEMENTS</h4></p>
+             <p>U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" target="_blank">Xuebin Qin</a>.</p>
+             <p>Codes modified from <a href="https://github.com/levindabhi/cloth-segmentation" target="_blank">levindabhi/cloth-segmentation</a></p>
+         </div>
+         """
      )

+ # For Hugging Face Spaces, use launch() without share=True
+ demo.launch()