Stable-X commited on
Commit
10e5f03
·
1 Parent(s): ebc67de

Improve app

Browse files
Files changed (2) hide show
  1. app.py +86 -31
  2. requirements.txt +3 -3
app.py CHANGED
@@ -19,23 +19,54 @@ class Examples(gr.helpers.Examples):
19
  self.cached_file = Path(self.cached_folder) / "log.csv"
20
  self.create()
21
 
22
- def load_predictor():
23
- """Load model predictor using torch.hub"""
24
- predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo",
25
- trust_repo=True, yoso_version='yoso-normal-v1-8-1')
26
- return predictor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def process_image(
29
- predictor,
30
  path_input: str,
 
31
  data_type: str = "object"
32
  ) -> tuple:
33
- """Process single image"""
34
  if path_input is None:
35
  raise gr.Error("Please upload an image or select one from the gallery.")
 
 
 
36
 
37
  name_base = os.path.splitext(os.path.basename(path_input))[0]
38
- out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal.png")
39
 
40
  # Load and process image
41
  input_image = Image.open(path_input)
@@ -45,16 +76,15 @@ def process_image(
45
  yield [input_image, out_path]
46
 
47
  def create_demo():
48
- # Load model
49
- predictor = load_predictor()
50
 
51
- # Create processing functions for each data type
52
- process_object = spaces.GPU(functools.partial(process_image, predictor, data_type="object"))
53
 
54
  # Define markdown content
55
  HEADER_MD = """
56
  # 🎪 StableNormal Turbo
57
-
58
  <p align="center">
59
  <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
60
  <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
@@ -69,6 +99,8 @@ def create_demo():
69
  <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
70
  </a>
71
  </p>
 
 
72
  """
73
 
74
  # Create interface
@@ -92,9 +124,19 @@ def create_demo():
92
  with gr.Row():
93
  with gr.Column():
94
  object_input = gr.Image(label="Input Object Image", type="filepath")
 
 
 
 
 
 
 
 
 
95
  with gr.Row():
96
  object_submit_btn = gr.Button("Compute Normal", variant="primary")
97
  object_reset_btn = gr.Button("Reset")
 
98
  with gr.Column():
99
  object_output_slider = ImageSlider(
100
  label="Normal outputs",
@@ -106,36 +148,49 @@ def create_demo():
106
  position=0.25,
107
  )
108
 
109
- Examples(
110
- fn=process_object,
111
- examples=sorted([
112
- os.path.join("files", "object", name)
113
- for name in os.listdir(os.path.join("files", "object"))
114
- if os.path.exists(os.path.join("files", "object"))
115
- ]),
116
- inputs=[object_input],
117
- outputs=[object_output_slider],
118
- cache_examples=False,
119
- directory_name="examples_object",
120
- examples_per_page=50,
121
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # Event Handlers for Object Tab
124
  object_submit_btn.click(
125
- fn=lambda x, _: None if x else gr.Error("Please upload an image"),
126
- inputs=object_input,
127
  outputs=None,
128
  queue=False,
129
  ).success(
130
  fn=process_object,
131
- inputs=object_input,
132
  outputs=[object_output_slider],
133
  )
134
 
135
  object_reset_btn.click(
136
- fn=lambda: (None, DEFAULT_SHARPNESS, None),
137
  inputs=[],
138
- outputs=[object_input, object_output_slider],
139
  queue=False,
140
  )
141
 
 
19
  self.cached_file = Path(self.cached_folder) / "log.csv"
20
  self.create()
21
 
22
+ # Global variable to store loaded predictors
23
+ predictors = {}
24
+
25
+ # Available model versions
26
+ MODEL_VERSIONS = {
27
+ "v0.3": "yoso-normal-v0-3",
28
+ "v1.0": "yoso-normal-v1-0",
29
+ "v1.5": "yoso-normal-v1-5",
30
+ "v1.8.1": "yoso-normal-v1-8-1"
31
+ }
32
+
33
+ def load_predictor(version: str = "v1.8.1"):
34
+ """Load model predictor using torch.hub with specified version"""
35
+ if version not in predictors:
36
+ yoso_version = MODEL_VERSIONS[version]
37
+ print(f"Loading StableNormal with {yoso_version}...")
38
+ predictor = torch.hub.load("Stable-X/StableNormal", "StableNormal_turbo",
39
+ trust_repo=True, yoso_version=yoso_version)
40
+ predictors[version] = predictor
41
+ print(f"Successfully loaded {version}")
42
+ return predictors[version]
43
+
44
+ def precache_all_predictors():
45
+ """Precache all model predictors at startup"""
46
+ print("Precaching all StableNormal predictors...")
47
+ for version in MODEL_VERSIONS.keys():
48
+ print(f"Precaching {version}...")
49
+ try:
50
+ load_predictor(version)
51
+ print(f"✓ Successfully precached {version}")
52
+ except Exception as e:
53
+ print(f"✗ Failed to precache {version}: {e}")
54
+ print("Finished precaching all predictors.")
55
 
56
  def process_image(
 
57
  path_input: str,
58
+ version: str = "v1.8.1",
59
  data_type: str = "object"
60
  ) -> tuple:
61
+ """Process single image with specified model version"""
62
  if path_input is None:
63
  raise gr.Error("Please upload an image or select one from the gallery.")
64
+
65
+ # Load the predictor for the specified version
66
+ predictor = load_predictor(version)
67
 
68
  name_base = os.path.splitext(os.path.basename(path_input))[0]
69
+ out_path = os.path.join(tempfile.mkdtemp(), f"{name_base}_normal_{version.replace('.', '_')}.png")
70
 
71
  # Load and process image
72
  input_image = Image.open(path_input)
 
76
  yield [input_image, out_path]
77
 
78
  def create_demo():
79
+ # Precache all predictors before creating the demo
80
+ precache_all_predictors()
81
 
82
+ # Create processing function
83
+ process_object = spaces.GPU(process_image)
84
 
85
  # Define markdown content
86
  HEADER_MD = """
87
  # 🎪 StableNormal Turbo
 
88
  <p align="center">
89
  <a title="Website" href="https://stable-x.github.io/StableNormal/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
90
  <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
 
99
  <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
100
  </a>
101
  </p>
102
+
103
+ Select between different YOSO Normal model versions. Each version may have different performance characteristics and quality trade-offs.
104
  """
105
 
106
  # Create interface
 
124
  with gr.Row():
125
  with gr.Column():
126
  object_input = gr.Image(label="Input Object Image", type="filepath")
127
+
128
+ # Model version selector
129
+ version_dropdown = gr.Dropdown(
130
+ choices=list(MODEL_VERSIONS.keys()),
131
+ value="v1.8.1",
132
+ label="Model Version",
133
+ info="Select YOSO Normal model version"
134
+ )
135
+
136
  with gr.Row():
137
  object_submit_btn = gr.Button("Compute Normal", variant="primary")
138
  object_reset_btn = gr.Button("Reset")
139
+
140
  with gr.Column():
141
  object_output_slider = ImageSlider(
142
  label="Normal outputs",
 
148
  position=0.25,
149
  )
150
 
151
+ # Model version info
152
+ with gr.Row():
153
+ gr.Markdown("""
154
+ **Model Version Information:**
155
+ - **v0.3**: Camera Ready Version
156
+ - **v1.0**: Improve stability, but poor sharpness
157
+ - **v1.5**: Enhanced performance and accuracy
158
+ - **v1.8.1**: Latest version with best sharpness (default)
159
+
160
+ *All models are precached and ready for instant switching.*
161
+ """)
162
+
163
+ # Examples section
164
+ if os.path.exists(os.path.join("files", "object")):
165
+ Examples(
166
+ fn=lambda img, ver: process_object(img, ver),
167
+ examples=sorted([
168
+ [os.path.join("files", "object", name), "v1.8.1"]
169
+ for name in os.listdir(os.path.join("files", "object"))
170
+ ]),
171
+ inputs=[object_input, version_dropdown],
172
+ outputs=[object_output_slider],
173
+ cache_examples=False,
174
+ directory_name="examples_object",
175
+ examples_per_page=50,
176
+ )
177
 
178
  # Event Handlers for Object Tab
179
  object_submit_btn.click(
180
+ fn=lambda x, v: None if x else gr.Error("Please upload an image"),
181
+ inputs=[object_input, version_dropdown],
182
  outputs=None,
183
  queue=False,
184
  ).success(
185
  fn=process_object,
186
+ inputs=[object_input, version_dropdown],
187
  outputs=[object_output_slider],
188
  )
189
 
190
  object_reset_btn.click(
191
+ fn=lambda: (None, "v1.8.1", None),
192
  inputs=[],
193
+ outputs=[object_input, version_dropdown, object_output_slider],
194
  queue=False,
195
  )
196
 
requirements.txt CHANGED
@@ -30,9 +30,9 @@ filelock==3.14.0
30
  fonttools==4.53.0
31
  frozenlist==1.4.1
32
  fsspec==2024.3.1
33
- gradio==4.32.2
34
- gradio_client==0.17.0
35
- gradio_imageslider==0.0.20
36
  h11==0.14.0
37
  httpcore==1.0.5
38
  httptools==0.6.1
 
30
  fonttools==4.53.0
31
  frozenlist==1.4.1
32
  fsspec==2024.3.1
33
+ gradio==4.44.1
34
+ gradio_client
35
+ gradio_imageslider
36
  h11==0.14.0
37
  httpcore==1.0.5
38
  httptools==0.6.1