wildoctopus committed (verified)
Commit 6984480 · Parent(s): 56608f8

Update app.py

Files changed (1):
  app.py  +40 −12
app.py CHANGED
@@ -4,22 +4,48 @@ import gradio as gr
 import os
 from process import load_seg_model, get_palette, generate_mask
 
-device = 'cpu'
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 def read_content(file_path: str) -> str:
-    with open(file_path, 'r', encoding='utf-8') as f:
-        return f.read()
+    """Read file content with error handling"""
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+    except FileNotFoundError:
+        print(f"Warning: File {file_path} not found")
+        return ""
+    except Exception as e:
+        print(f"Error reading file {file_path}: {str(e)}")
+        return ""
 
 def initialize_and_load_models():
-    checkpoint_path = 'model/cloth_segm.pth'
-    return load_seg_model(checkpoint_path, device=device)
+    """Initialize and load models with error handling"""
+    try:
+        checkpoint_path = 'model/cloth_segm.pth'
+        if not os.path.exists(checkpoint_path):
+            raise FileNotFoundError(f"Model checkpoint not found at {checkpoint_path}")
+        return load_seg_model(checkpoint_path, device=device)
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        return None
 
 net = initialize_and_load_models()
+if net is None:
+    raise RuntimeError("Failed to load model - check logs for details")
+
 palette = get_palette(4)
 
 def run(img):
-    cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
-    return cloth_seg
+    """Process image with error handling"""
+    if img is None:
+        raise gr.Error("No image uploaded")
+    try:
+        cloth_seg = generate_mask(img, net=net, palette=palette, device=device)
+        if cloth_seg is None:
+            raise gr.Error("Failed to generate mask")
+        return cloth_seg
+    except Exception as e:
+        raise gr.Error(f"Error processing image: {str(e)}")
 
 # CSS styling
 css = '''
@@ -36,8 +62,10 @@ css = '''
 
 # Collect example images
 image_dir = 'input'
-image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
-image_list.sort()
+image_list = []
+if os.path.exists(image_dir):
+    image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.lower().endswith(('.png', '.jpg', '.jpeg'))]
+    image_list.sort()
 examples = [[img] for img in image_list]
 
 with gr.Blocks(css=css) as demo:
@@ -57,7 +85,7 @@ with gr.Blocks(css=css) as demo:
         label="Examples - Input Images",
         examples_per_page=12
     )
-    btn = gr.Button("Run!")
+    btn = gr.Button("Run!", variant="primary")
 
     btn.click(fn=run, inputs=[image], outputs=[image_out])
 
@@ -74,5 +102,5 @@ with gr.Blocks(css=css) as demo:
        """
    )
 
-# Ensure the app works in Hugging Face by sharing a public link
-demo.launch(share=True)
+# For Hugging Face Spaces, use launch() without share=True
+demo.launch()
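
The updated run() reports failures through Gradio's gr.Error instead of printing to the logs, and the app now calls launch() without share=True since a Space is already publicly reachable. Below is a minimal, self-contained sketch of that pattern under the same component names the diff uses (image, image_out, btn); the placeholder body stands in for the real generate_mask() call and is not part of this commit.

```python
import gradio as gr

def run(img):
    # Raising gr.Error inside an event handler shows the message in the UI
    # instead of failing silently; this is the pattern the updated run() follows.
    if img is None:
        raise gr.Error("No image uploaded")
    return img  # placeholder; the real app returns generate_mask(img, ...)

with gr.Blocks() as demo:
    image = gr.Image(label="Input")
    image_out = gr.Image(label="Output")
    btn = gr.Button("Run!", variant="primary")
    btn.click(fn=run, inputs=[image], outputs=[image_out])

# On Hugging Face Spaces the app is already public, so launch() needs no share=True.
demo.launch()
```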