Update app.py
app.py CHANGED
@@ -11,14 +11,13 @@ from torch import Tensor
 from genstereo import GenStereo, AdaptiveFusionLayer
 import ssl
 from huggingface_hub import hf_hub_download
-import spaces

 from extern.DAM2.depth_anything_v2.dpt import DepthAnythingV2
 ssl._create_default_https_context = ssl._create_unverified_context

-IMAGE_SIZE = 512
+IMAGE_SIZE = 768
 DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
-CHECKPOINT_NAME = 'genstereo'
+CHECKPOINT_NAME = 'genstereo-v2.1'

 def download_models():
     models = [
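Note: the new module-level defaults match the "v2.1" option that the UI preselects further down. A minimal sketch of keeping the two in sync via one table (the names `SD_TABLE` and `SD_DEFAULT` are illustrative, not part of the commit):

```python
# Keep the module defaults in sync with the UI's initial radio value ("v2.1").
SD_TABLE = {'v1.5': (512, 'genstereo-v1.5'), 'v2.1': (768, 'genstereo-v2.1')}
SD_DEFAULT = 'v2.1'

IMAGE_SIZE, CHECKPOINT_NAME = SD_TABLE[SD_DEFAULT]
```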
@@ -39,7 +38,14 @@ def download_models():
         {
             'repo': 'FQiao/GenStereo',
             'sub': None,
-            'dst': 'checkpoints/genstereo',
+            'dst': 'checkpoints/genstereo-v1.5',
+            'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
+            'token': None
+        },
+        {
+            'repo': 'FQiao/GenStereo-sd2.1',
+            'sub': None,
+            'dst': 'checkpoints/genstereo-v2.1',
             'files': ['config.json', 'denoising_unet.pth', 'fusion_layer.pth', 'pose_guider.pth', 'reference_unet.pth'],
             'token': None
         },
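Note: the manifest now stages both checkpoints under version-suffixed directories. The loop that consumes it is outside this hunk; a sketch of how such a manifest is typically fed to `hf_hub_download` (the function name `fetch_all` and the loop are assumptions):

```python
from huggingface_hub import hf_hub_download

def fetch_all(models):
    # Download every listed file into its version-specific checkpoint directory.
    for m in models:
        for fname in m['files']:
            hf_hub_download(
                repo_id=m['repo'],    # e.g. 'FQiao/GenStereo-sd2.1'
                filename=fname,       # e.g. 'denoising_unet.pth'
                subfolder=m['sub'],   # None for these entries
                local_dir=m['dst'],   # e.g. 'checkpoints/genstereo-v2.1'
                token=m['token'],     # None; both repos are public
            )
```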
@@ -86,13 +92,13 @@ def get_dam2_model():
     return dam2

 # GenStereo
-def get_genstereo_model():
-    genstereo_cfg = dict(
+def get_genstereo_model(sd_version):
+    genstereo_cfg = dict(
         pretrained_model_path='checkpoints',
         checkpoint_name=CHECKPOINT_NAME,
         half_precision_weights=True
     )
-    genstereo = GenStereo(cfg=
+    genstereo = GenStereo(cfg=genstereo_cfg, device=DEVICE, sd_version=sd_version)
     return genstereo

 # Adaptive Fusion
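Note: `get_genstereo_model` now receives `sd_version`, but it still resolves the checkpoint through the mutable global `CHECKPOINT_NAME`, and it rebuilds the model on every call. A hedged alternative sketch that derives the checkpoint from the argument and caches one model per version (assuming `GenStereo` construction depends only on these arguments; `get_genstereo_model_cached` is illustrative):

```python
from functools import lru_cache

@lru_cache(maxsize=2)  # at most one cached model per SD version
def get_genstereo_model_cached(sd_version: str):
    # Resolve the checkpoint locally instead of reading a shared global.
    checkpoint = {'v1.5': 'genstereo-v1.5', 'v2.1': 'genstereo-v2.1'}[sd_version]
    cfg = dict(
        pretrained_model_path='checkpoints',
        checkpoint_name=checkpoint,
        half_precision_weights=True,
    )
    return GenStereo(cfg=cfg, device=DEVICE, sd_version=sd_version)
```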
@@ -128,14 +134,32 @@ with tempfile.TemporaryDirectory() as tmpdir:
     src_depth = gr.State()

     # Callbacks
-    @spaces.GPU()
-    def cb_mde(image_file: str):
+    def cb_update_sd_version(sd_version_choice):
+        global IMAGE_SIZE, CHECKPOINT_NAME
+        if sd_version_choice == "v1.5":
+            IMAGE_SIZE = 512
+            CHECKPOINT_NAME = 'genstereo-v1.5'
+            print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
+        elif sd_version_choice == "v2.1":
+            IMAGE_SIZE = 768
+            CHECKPOINT_NAME = 'genstereo-v2.1'
+            print(f"Switched to GenStereo {sd_version_choice}. IMAGE_SIZE: {IMAGE_SIZE}, CHECKPOINT: {CHECKPOINT_NAME}")
+        return None, None, None, None, None, None
+
+    def cb_mde(image_file: str, sd_version):
         if not image_file:
             # Return None if no image is provided (e.g., when file is cleared).
             return None, None, None, None

         image = crop(Image.open(image_file).convert('RGB'))  # Load image using PIL
-        image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+        if sd_version == "v1.5":
+            image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+        elif sd_version == "v2.1":
+            image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+        else:
+            gr.Warning(f"Unknown SD version: {sd_version}. Defaulting to {IMAGE_SIZE}.")
+            image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
+        gr.Info(f"Generating with GenStereo {sd_version} at {IMAGE_SIZE}px resolution.")

         image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

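Note: `cb_update_sd_version` mutates module globals, which are shared by every connected Gradio session in the process, so two users on different versions can race each other; also, as committed, the three resize branches in `cb_mde` are identical, since `IMAGE_SIZE` has already been switched by the radio callback. A sketch of resolving settings per call instead (names are illustrative):

```python
# Per-version settings looked up inside each callback instead of
# mutating module globals shared across Gradio sessions.
SD_SETTINGS = {
    'v1.5': {'image_size': 512, 'checkpoint': 'genstereo-v1.5'},
    'v2.1': {'image_size': 768, 'checkpoint': 'genstereo-v2.1'},
}

def settings_for(sd_version):
    # Unknown values fall back to the UI default, v2.1.
    return SD_SETTINGS.get(sd_version, SD_SETTINGS['v2.1'])

# e.g. inside cb_mde:
#   size = settings_for(sd_version)['image_size']
#   image = image.resize((size, size))
```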
@@ -147,12 +171,11 @@ with tempfile.TemporaryDirectory() as tmpdir:

         return image, depth_image, image, depth

-    @spaces.GPU()
-    def cb_generate(image, depth: Tensor, scale_factor):
+    def cb_generate(image, depth: Tensor, scale_factor, sd_version):
         norm_disp = normalize_disp(depth.cuda())
         disp = norm_disp * scale_factor / 100 * IMAGE_SIZE

-        genstereo = get_genstereo_model()
+        genstereo = get_genstereo_model(sd_version)
         fusion_model = get_fusion_model()

         renders = genstereo(
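Note: assuming `normalize_disp` maps the depth into [0, 1], the slider acts as a percentage of image width: with `scale_factor = 15` and `IMAGE_SIZE = 768`, the maximum disparity is 1.0 × 15 / 100 × 768 ≈ 115 px. Separately, `depth.cuda()` pins the tensor to CUDA even though `DEVICE` may resolve to 'mps' or 'cpu'; a device-agnostic sketch of the same two lines (using the app's own names):

```python
norm_disp = normalize_disp(depth.to(DEVICE))          # follow the module-wide device choice
disp = norm_disp * scale_factor / 100 * IMAGE_SIZE    # disparity in pixels
```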
@@ -174,27 +197,44 @@ with tempfile.TemporaryDirectory() as tmpdir:
     # Blocks.
     gr.Markdown(
         """
-        # 
-
-        [
-
+        # StereoGen: Towards Open-World Generation of Stereo Images and Unsupervised Matching
+        [](https://qjizhi.github.io/genstereo)
+        [](https://huggingface.co/spaces/FQiao/GenStereo)
+        [](https://github.com/Qjizhi/GenStereo)
+        [](https://huggingface.co/FQiao/GenStereo/tree/main)
+        [](https://arxiv.org/abs/2503.12720)
+
         ## Introduction
-        This is an official demo for the paper "
-
+        This is an official demo for the paper "[Towards Open-World Generation of Stereo Images and Unsupervised Matching](https://qjizhi.github.io/genstereo)". Given an arbitrary reference image, GenStereo can generate the corresponding right-view image.
+
         ## How to Use
-
-
-
+
+        1. Select the GenStereo version
+            - v1.5: 512px, faster.
+            - v2.1: 768px, better performance, high resolution, takes more time.
+        2. Upload a reference image to "Left Image"
+            - You can also select an image from "Examples"
+        3. Hit "Generate a right image" button and check the result.
+
         """
     )
-    file = gr.File(label='Left', file_types=['image'])
-    examples = gr.Examples(
-        examples=['./assets/COCO_val2017_000000070229.jpg',
-                  './assets/COCO_val2017_000000092839.jpg',
-                  './assets/KITTI2015_000003_10.png',
-                  './assets/KITTI2015_000147_10.png'],
-        inputs=file
+
+    sd_version_radio = gr.Radio(
+        label="GenStereo Version",
+        choices=["v1.5", "v2.1"],
+        value="v2.1",
     )
+
+    with gr.Row():
+
+        file = gr.File(label='Left', file_types=['image'])
+        examples = gr.Examples(
+            examples=['./assets/COCO_val2017_000000070229.jpg',
+                      './assets/COCO_val2017_000000092839.jpg',
+                      './assets/KITTI2015_000003_10.png',
+                      './assets/KITTI2015_000147_10.png'],
+            inputs=file
+        )
     with gr.Row():
         image_widget = gr.Image(
             label='Left Image', type='filepath',
@@ -221,14 +261,23 @@ with tempfile.TemporaryDirectory() as tmpdir:
         )

     # Events
+    sd_version_radio.change(
+        fn=cb_update_sd_version,
+        inputs=sd_version_radio,
+        outputs=[
+            image_widget, depth_widget,  # Clear image displays
+            src_image, src_depth,        # Clear internal states
+            warped_widget, gen_widget    # Clear generation outputs
+        ]
+    )
     file.change(
         fn=cb_mde,
-        inputs=file,
+        inputs=[file, sd_version_radio],
         outputs=[image_widget, depth_widget, src_image, src_depth]
     )
     button.click(
         fn=cb_generate,
-        inputs=[src_image, src_depth, scale_slider],
+        inputs=[src_image, src_depth, scale_slider, sd_version_radio],
         outputs=[warped_widget, gen_widget]
     )

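Note: the wiring above implements a clear-on-switch pattern: the radio's change handler returns `None` for every dependent component, resetting stale images and states before anything is regenerated at the other resolution. A standalone sketch of the same pattern (component names are illustrative, not the app's):

```python
import gradio as gr

with gr.Blocks() as demo:
    version = gr.Radio(choices=['v1.5', 'v2.1'], value='v2.1', label='Version')
    left = gr.Image(label='Left')
    right = gr.Image(label='Right')

    # Returning None for each output resets it, so results produced at the
    # other resolution never linger after the version is switched.
    version.change(fn=lambda v: (None, None), inputs=version, outputs=[left, right])

demo.launch()
```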