FQiao committed on
Commit
97d15c9
·
verified ·
1 Parent(s): 984c5f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -31
app.py CHANGED
@@ -11,14 +11,13 @@ from torch import Tensor
11
  from genstereo import GenStereo, AdaptiveFusionLayer
12
  import ssl
13
  from huggingface_hub import hf_hub_download
14
- import spaces
15
 
16
  from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
17
  ssl._create_default_https_context = ssl._create_unverified_context
18
 
19
- IMAGE_SIZE = 512
20
  DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
21
- CHECKPOINT_NAME = 'genstereo'
22
 
23
  def download_models():
24
  models = [
@@ -39,7 +38,14 @@ def download_models():
39
  {
40
  'repo': 'FQiao/GenStereo',
41
  'sub': None,
42
- 'dst': 'checkpoints/genstereo',
 
 
 
 
 
 
 
43
  'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
44
  'token': None
45
  },
@@ -86,13 +92,13 @@ def get_dam2_model():
86
  return dam2
87
 
88
  # GenStereo
89
- def get_genstereo_model():
90
- genwarp_cfg = dict(
91
  pretrained_model_path='checkpoints',
92
  checkpoint_name=CHECKPOINT_NAME,
93
  half_precision_weights=True
94
  )
95
- genstereo = GenStereo(cfg=genwarp_cfg, device=DEVICE)
96
  return genstereo
97
 
98
  # Adaptive Fusion
@@ -128,14 +134,32 @@ with tempfile.TemporaryDirectory() as tmpdir:
128
  src_depth = gr.State()
129
 
130
  # Callbacks
131
- @spaces.GPU()
132
- def cb_mde(image_file: str):
 
 
 
 
 
 
 
 
 
 
 
133
  if not image_file:
134
  # Return None if no image is provided (e.g., when file is cleared).
135
  return None, None, None, None
136
 
137
  image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
138
- image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
 
 
 
 
 
 
 
139
 
140
  image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
141
 
@@ -147,12 +171,11 @@ with tempfile.TemporaryDirectory() as tmpdir:
147
 
148
  return image, depth_image, image, depth
149
 
150
- @spaces.GPU()
151
- def cb_generate(image, depth: Tensor, scale_factor):
152
  norm_disp = normalize_disp(depth.cuda())
153
  disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
154
 
155
- genstereo = get_genstereo_model()
156
  fusion_model = get_fusion_model()
157
 
158
  renders = genstereo(
@@ -174,27 +197,44 @@ with tempfile.TemporaryDirectory() as tmpdir:
174
  # Blocks.
175
  gr.Markdown(
176
  """
177
- # GenStereo: Towards Open-World Generation of Stereo Images and Unsupervised Matching
178
-
179
- [Project Web](https://qjizhi.github.io/genstereo) · [Spaces Demo](https://huggingface.co/spaces/FQiao/GenStereo) · [GitHub Repo](https://github.com/Qjizhi/GenStereo) · [Checkpoints](https://huggingface.co/FQiao/GenStereo/tree/main) · [Paper](https://arxiv.org/abs/2503.12720)
180
-
 
 
 
181
  ## Introduction
182
- This is an official demo for the paper "GenStereo: Towards Open-World Generation of Stereo Images and Unsupervised Matching". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
183
-
184
  ## How to Use
185
- 1. Upload a reference image to "Left Image"
186
- - You can also select an image from "Examples"
187
- 2. Hit "Generate a right image" button and check the result
 
 
 
 
 
188
  """
189
  )
190
- file = gr.File(label='Left', file_types=['image'])
191
- examples = gr.Examples(
192
- examples=['./assets/COCO_val2017_000000070229.jpg',
193
- './assets/COCO_val2017_000000092839.jpg',
194
- './assets/KITTI2015_000003_10.png',
195
- './assets/KITTI2015_000147_10.png'],
196
- inputs=file
197
  )
 
 
 
 
 
 
 
 
 
 
 
198
  with gr.Row():
199
  image_widget = gr.Image(
200
  label='Left Image', type='filepath',
@@ -221,14 +261,23 @@ with tempfile.TemporaryDirectory() as tmpdir:
221
  )
222
 
223
  # Events
 
 
 
 
 
 
 
 
 
224
  file.change(
225
  fn=cb_mde,
226
- inputs=file,
227
  outputs=[image_widget, depth_widget, src_image, src_depth]
228
  )
229
  button.click(
230
  fn=cb_generate,
231
- inputs=[src_image, src_depth, scale_slider],
232
  outputs=[warped_widget, gen_widget]
233
  )
234
 
 
11
  from genstereo import GenStereo, AdaptiveFusionLayer
12
  import ssl
13
  from huggingface_hub import hf_hub_download
 
14
 
15
  from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
16
  ssl._create_default_https_context = ssl._create_unverified_context
17
 
18
+ IMAGE_SIZE = 768
19
  DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
20
+ CHECKPOINT_NAME = 'genstereo-v2.1'
21
 
22
  def download_models():
23
  models = [
 
38
  {
39
  'repo': 'FQiao/GenStereo',
40
  'sub': None,
41
+ 'dst': 'checkpoints/genstereo-v1.5',
42
+ 'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
43
+ 'token': None
44
+ },
45
+ {
46
+ 'repo': 'FQiao/GenStereo-sd2.1',
47
+ 'sub': None,
48
+ 'dst': 'checkpoints/genstereo-v2.1',
49
  'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
50
  'token': None
51
  },
 
92
  return dam2
93
 
94
  # GenStereo
95
+ def get_genstereo_model(sd_version):
96
+ genstereo_cfg = dict(
97
  pretrained_model_path='checkpoints',
98
  checkpoint_name=CHECKPOINT_NAME,
99
  half_precision_weights=True
100
  )
101
+ genstereo = GenStereo(cfg=genstereo_cfg, device=DEVICE, sd_version=sd_version)
102
  return genstereo
103
 
104
  # Adaptive Fusion
 
134
  src_depth = gr.State()
135
 
136
  # Callbacks
137
+ def cb_update_sd_version(sd_version_choice):
138
+ global IMAGE_SIZE, CHECKPOINT_NAME
139
+ if sd_version_choice == "v1.5":
140
+ IMAGE_SIZE = 512
141
+ CHECKPOINT_NAME = 'genstereo-v1.5'
142
+ print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
143
+ elif sd_version_choice == "v2.1":
144
+ IMAGE_SIZE = 768
145
+ CHECKPOINT_NAME = 'genstereo-v2.1'
146
+ print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
147
+ return None, None, None, None, None, None
148
+
149
+ def cb_mde(image_file: str, sd_version):
150
  if not image_file:
151
  # Return None if no image is provided (e.g., when file is cleared).
152
  return None, None, None, None
153
 
154
  image = crop(Image.open(image_file).convert('RGB')) # Load image using PIL
155
+ if sd_version == "v1.5":
156
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
157
+ elif sd_version == "v2.1":
158
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
159
+ else:
160
+ gr.Warning(f"Unknown SD version: {sd_version}. Defaulting to {IMAGE_SIZE}.")
161
+ image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
162
+ gr.Info(f"Generating with GenStereo {sd_version} at {IMAGE_SIZE}px resolution.")
163
 
164
  image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
165
 
 
171
 
172
  return image, depth_image, image, depth
173
 
174
+ def cb_generate(image, depth: Tensor, scale_factor, sd_version):
 
175
  norm_disp = normalize_disp(depth.cuda())
176
  disp = norm_disp * scale_factor / 100 * IMAGE_SIZE
177
 
178
+ genstereo = get_genstereo_model(sd_version)
179
  fusion_model = get_fusion_model()
180
 
181
  renders = genstereo(
 
197
  # Blocks.
198
  gr.Markdown(
199
  """
200
+ # StereoGen: Towards Open-World Generation of Stereo Images and Unsupervised Matching
201
+ [![Project Site](https://img.shields.io/badge/Project-Web-green)](https://qjizhi.github.io/genstereo)  
202
+ [![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/FQiao/GenStereo)  
203
+ [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/Qjizhi/GenStereo)  
204
+ [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/FQiao/GenStereo/tree/main)  
205
+ [![arXiv](https://img.shields.io/badge/arXiv-2503.12720-red?logo=arxiv)](https://arxiv.org/abs/2503.12720)
206
+
207
  ## Introduction
208
+ This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
209
+
210
  ## How to Use
211
+
212
+ 1. Select the GenStereo version
213
+ - v1.5: 512px, faster.
214
+ - v2.1: 768px, better performance, high resolution, takes more time.
215
+ 2. Upload a reference image to "Left Image"
216
+ - You can also select an image from "Examples"
217
+ 3. Hit "Generate a right image" button and check the result.
218
+
219
  """
220
  )
221
+
222
+ sd_version_radio = gr.Radio(
223
+ label="GenStereo Version",
224
+ choices=["v1.5", "v2.1"],
225
+ value="v2.1",
 
 
226
  )
227
+
228
+ with gr.Row():
229
+
230
+ file = gr.File(label='Left', file_types=['image'])
231
+ examples = gr.Examples(
232
+ examples=['./assets/COCO_val2017_000000070229.jpg',
233
+ './assets/COCO_val2017_000000092839.jpg',
234
+ './assets/KITTI2015_000003_10.png',
235
+ './assets/KITTI2015_000147_10.png'],
236
+ inputs=file
237
+ )
238
  with gr.Row():
239
  image_widget = gr.Image(
240
  label='Left Image', type='filepath',
 
261
  )
262
 
263
  # Events
264
+ sd_version_radio.change(
265
+ fn=cb_update_sd_version,
266
+ inputs=sd_version_radio,
267
+ outputs=[
268
+ image_widget, depth_widget, # Clear image displays
269
+ src_image, src_depth, # Clear internal states
270
+ warped_widget, gen_widget # Clear generation outputs
271
+ ]
272
+ )
273
  file.change(
274
  fn=cb_mde,
275
+ inputs=[file, sd_version_radio],
276
  outputs=[image_widget, depth_widget, src_image, src_depth]
277
  )
278
  button.click(
279
  fn=cb_generate,
280
+ inputs=[src_image, src_depth, scale_slider, sd_version_radio],
281
  outputs=[warped_widget, gen_widget]
282
  )
283