kxic committed on
Commit 34eb8c0 · verified · 1 Parent(s): c8edbc1

Delete gradio_demo.py

Files changed (1)
  1. gradio_demo.py +0 -733
gradio_demo.py DELETED
@@ -1,733 +0,0 @@
- import gradio as gr
- import os
- import shutil
- import rembg
- import numpy as np
- import math
- import open3d as o3d
- from PIL import Image
- import torch
- import torchvision
- import trimesh
- from skimage.io import imsave
- import imageio
- import cv2
- import matplotlib.pyplot as pl
- pl.ion()
-
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- weight_dtype = torch.float16
- torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
-
- # EscherNet
- # create angles in archimedean spiral with N steps
- def get_archimedean_spiral(sphere_radius, num_steps=250):
-     # x-z plane, around upper y
-     '''
-     https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
-     '''
-     a = 40
-     r = sphere_radius
-
-     translations = []
-     angles = []
-
-     # i = a / 2
-     i = 0.01
-     while i < a:
-         theta = i / a * math.pi
-         x = r * math.sin(theta) * math.cos(-i)
-         z = r * math.sin(-theta + math.pi) * math.sin(-i)
-         y = r * - math.cos(theta)
-
-         # translations.append((x, y, z))  # origin
-         translations.append((x, z, -y))
-         angles.append([np.rad2deg(-i), np.rad2deg(theta)])
-
-         # i += a / (2 * num_steps)
-         i += a / (1 * num_steps)
-
-     return np.array(translations), np.stack(angles)
-
- def look_at(origin, target, up):
-     forward = (target - origin)
-     forward = forward / np.linalg.norm(forward)
-     right = np.cross(up, forward)
-     right = right / np.linalg.norm(right)
-     new_up = np.cross(forward, right)
-     rotation_matrix = np.column_stack((right, new_up, -forward, target))
-     matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
-     return matrix
-
- CaPE_TYPE = "6DoF"
- import einops
- if CaPE_TYPE == "6DoF":
-     import sys
-
-     sys.path.insert(0, "../6DoF/")
-     # use the customized diffusers modules
-     from diffusers import DDIMScheduler
-     from dataset import get_pose
-     from CN_encoder import CN_encoder
-     from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
-
- elif CaPE_TYPE == "4DoF":
-     import sys
-
-     sys.path.insert(0, "../4DoF/")
-     # use the customized diffusers modules
-     from diffusers import DDIMScheduler
-     from dataset import get_pose
-     from CN_encoder import CN_encoder
-     from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
- else:
-     raise ValueError("CaPE_TYPE must be chosen from 4DoF, 6DoF")
-
-
- pretrained_model_name_or_path = "XY-Xin/N3M3B112G6_6dof_36k"  # TODO
- resolution = 256
- h, w = resolution, resolution
- guidance_scale = 3.0
- radius = 2.2
- bg_color = [1., 1., 1., 1.]
- image_transforms = torchvision.transforms.Compose(
-     [
-         torchvision.transforms.Resize((resolution, resolution)),  # 256, 256
-         torchvision.transforms.ToTensor(),
-         torchvision.transforms.Normalize([0.5], [0.5])
-     ]
- )
- xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
- # only use the top half of the spiral
- xyzs_spiral = xyzs_spiral[:100]
- angles_spiral = angles_spiral[:100]
-
- # Init pipeline
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
- image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
- pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
-     pretrained_model_name_or_path,
-     revision=None,
-     scheduler=scheduler,
-     image_encoder=None,
-     safety_checker=None,
-     feature_extractor=None,
-     torch_dtype=weight_dtype,
- )
- pipeline.image_encoder = image_encoder
- pipeline = pipeline.to(device)
- pipeline.set_progress_bar_config(disable=False)
-
- pipeline.enable_xformers_memory_efficient_attention()
- # enable vae slicing
- pipeline.enable_vae_slicing()
-
-
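- # EscherNet novel-view synthesis: given the DUSt3R-estimated input images and poses,
- # sample T_out target poses along the spiral, run the pipeline, and save PNG/GIF/MP4 outputs.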
- def run_eschernet(tmpdirname, eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
-     # set the random seed
-     generator = torch.Generator(device=device).manual_seed(sample_seed)
-     T_out = nvs_num
-     T_in = len(eschernet_input_dict['imgs'])
-     ####### output pose
-     # TODO choose T_out number of poses sequentially from the spiral
-     xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
-     angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
-
-     ####### input's max radius for translation scaling
-     radii = eschernet_input_dict['radii']
-     max_t = np.max(radii)
-     min_t = np.min(radii)
-
-     ####### input pose
-     pose_in = []
-     for T_in_index in range(T_in):
-         pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
-         pose[1:3, :] *= -1  # coordinate system conversion
-         pose[3, 3] *= 1. / max_t * radius  # scale radius to [1.5, 2.2]
-         pose_in.append(torch.from_numpy(pose))
-
-     ####### input image
-     img = eschernet_input_dict['imgs'] / 255.
-     img[img[:, :, :, -1] == 0.] = bg_color
-     # TODO batch image_transforms
-     input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
-
-     ####### nvs pose
-     pose_out = []
-     for T_out_index in range(T_out):
-         azimuth, polar = angles_out[T_out_index]
-         if CaPE_TYPE == "4DoF":
-             pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
-         elif CaPE_TYPE == "6DoF":
-             pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
-             pose = np.linalg.inv(pose)
-             pose[2, :] *= -1
-             pose_out.append(torch.from_numpy(get_pose(pose)))
-
-     # [B, T, C, H, W]
-     input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
-     # [B, T, 4]
-     pose_in = np.stack(pose_in)
-     pose_out = np.stack(pose_out)
-
-     if CaPE_TYPE == "6DoF":
-         pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
-         pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
-         pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
-         pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
-
-     pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
-     pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
-
-     input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
-     assert T_in == input_image.shape[0]
-     assert T_in == pose_in.shape[1]
-     assert T_out == pose_out.shape[1]
-
-     # run inference
-     if CaPE_TYPE == "6DoF":
-         with torch.autocast("cuda"):
-             image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
-                              poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
-                              height=h, width=w, T_in=T_in, T_out=T_out,
-                              guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
-                              output_type="numpy").images
-     elif CaPE_TYPE == "4DoF":
-         with torch.autocast("cuda"):
-             image = pipeline(input_imgs=input_image, prompt_imgs=input_image, poses=[pose_out, pose_in],
-                              height=h, width=w, T_in=T_in, T_out=T_out,
-                              guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
-                              output_type="numpy").images
-
-     # save output image
-     output_dir = os.path.join(tmpdirname, "eschernet")
-     if os.path.exists(output_dir):
-         shutil.rmtree(output_dir)
-     os.makedirs(output_dir, exist_ok=True)
-     # save to N imgs
-     for i in range(T_out):
-         imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
-     # make a gif
-     frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
-     frame_one = frames[0]
-     frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
-                    save_all=True, duration=50, loop=1)
-
-     # get a video
-     video_path = os.path.join(output_dir, "output.mp4")
-     imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
-
-     return image, video_path
-
- # TODO mesh it
- def make3d():
-     pass
-
-
- ############################ Dust3r as Pose Estimation ############################
- from scipy.spatial.transform import Rotation
- import copy
-
- from dust3r.inference import inference
- from dust3r.model import AsymmetricCroCo3DStereo
- from dust3r.image_pairs import make_pairs
- from dust3r.utils.image import load_images, rgb
- from dust3r.utils.device import to_numpy
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
-
- import functools
- import math
-
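- # export the reconstructed scene (point cloud or fused mesh plus camera frusta) as a .glb file for the gr.Model3D viewer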
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
-                                  cam_color=None, as_pointcloud=False,
-                                  transparent_cams=False, silent=False, same_focals=False):
-     assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
-     if not same_focals:
-         assert (len(cams2world) == len(focals))
-     pts3d = to_numpy(pts3d)
-     imgs = to_numpy(imgs)
-     focals = to_numpy(focals)
-     cams2world = to_numpy(cams2world)
-
-     scene = trimesh.Scene()
-
-     # add axes
-     scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
-
-     # full pointcloud
-     if as_pointcloud:
-         pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
-         col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
-         pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
-         scene.add_geometry(pct)
-     else:
-         meshes = []
-         for i in range(len(imgs)):
-             meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
-         mesh = trimesh.Trimesh(**cat_meshes(meshes))
-         scene.add_geometry(mesh)
-
-     # add each camera
-     for i, pose_c2w in enumerate(cams2world):
-         if isinstance(cam_color, list):
-             camera_edge_color = cam_color[i]
-         else:
-             camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
-         if same_focals:
-             focal = focals[0]
-         else:
-             focal = focals[i]
-         add_scene_cam(scene, pose_c2w, camera_edge_color,
-                       None if transparent_cams else imgs[i], focal,
-                       imsize=imgs[i].shape[1::-1], screen_width=cam_size)
-
-     rot = np.eye(4)
-     rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
-     scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
-     outfile = os.path.join(outdir, 'scene.glb')
-     if not silent:
-         print('(exporting 3D scene to', outfile, ')')
-     scene.export(file_obj=outfile)
-     return outfile
-
-
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
-                             clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
-     """
-     Extract a 3D model (.glb file) from a reconstructed scene.
-     """
-     if scene is None:
-         return None
-     # post processes
-     if clean_depth:
-         scene = scene.clean_pointcloud()
-     if mask_sky:
-         scene = scene.mask_sky()
-
-     # get optimized values from scene
-     rgbimg = to_numpy(scene.imgs)
-     focals = to_numpy(scene.get_focals().cpu())
-     # cams2world = to_numpy(scene.get_im_poses().cpu())
-     # TODO use the vis_poses
-     cams2world = scene.vis_poses
-
-     # 3D pointcloud from depthmap, poses and intrinsics
-     # pts3d = to_numpy(scene.get_pts3d())
-     # TODO use the vis_poses
-     pts3d = scene.vis_pts3d
-     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
-     msk = to_numpy(scene.get_masks())
-
-     return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
-                                         transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
-                                         same_focals=same_focals)
-
-
- def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
-                             as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
-                             scenegraph_type, winsize, refid, same_focals):
-     """
-     From a list of images, run DUSt3R inference and the global aligner,
-     then run get_3D_model_from_scene.
-     """
-     # remove the directory if it already exists
-     if os.path.exists(outdir):
-         shutil.rmtree(outdir)
-     os.makedirs(outdir, exist_ok=True)
-     imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True)
-     if len(imgs) == 1:
-         imgs = [imgs[0], copy.deepcopy(imgs[0])]
-         imgs[1]['idx'] = 1
-     if scenegraph_type == "swin":
-         scenegraph_type = scenegraph_type + "-" + str(winsize)
-     elif scenegraph_type == "oneref":
-         scenegraph_type = scenegraph_type + "-" + str(refid)
-
-     pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
-     output = inference(pairs, model, device, batch_size=1, verbose=not silent)
-
-     mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
-     scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
-     lr = 0.01
-
-     if mode == GlobalAlignerMode.PointCloudOptimizer:
-         loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
-
-     # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
-     #                                   clean_depth, transparent_cams, cam_size, same_focals=same_focals)
-
-     # also return rgb, depth and confidence imgs
-     # depth is normalized with the max value for all images
-     # we apply the jet colormap on the confidence maps
-     rgbimg = scene.imgs
-     # depths = to_numpy(scene.get_depthmaps())
-     # confs = to_numpy([c for c in scene.im_conf])
-     # cmap = pl.get_cmap('jet')
-     # depths_max = max([d.max() for d in depths])
-     # depths = [d / depths_max for d in depths]
-     # confs_max = max([d.max() for d in confs])
-     # confs = [cmap(d / confs_max) for d in confs]
-
-     imgs = []
-     rgbaimg = []
-     for i in range(len(rgbimg)):  # when only 1 image is uploaded, scene.imgs holds two copies
-         imgs.append(rgbimg[i])
-         # imgs.append(rgb(depths[i]))
-         # imgs.append(rgb(confs[i]))
-         # imgs.append(imgs_rgba[i])
-         if len(imgs_rgba) == 1 and i == 1:
-             imgs.append(imgs_rgba[0])
-             rgbaimg.append(np.array(imgs_rgba[0]))
-         else:
-             imgs.append(imgs_rgba[i])
-             rgbaimg.append(np.array(imgs_rgba[i]))
-
-     rgbaimg = np.array(rgbaimg)
-
-     # for eschernet
-     # get optimized values from scene
-     rgbimg = to_numpy(scene.imgs)
-     focals = to_numpy(scene.get_focals().cpu())
-     cams2world = to_numpy(scene.get_im_poses().cpu())
-
-     # 3D pointcloud from depthmap, poses and intrinsics
-     pts3d = to_numpy(scene.get_pts3d())
-     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
-     msk = to_numpy(scene.get_masks())
-     obj_mask = rgbaimg[..., 3] > 0
-
-     # TODO set global coordinate system at the center of the scene, z-axis is up
-     pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
-     pts_obj = np.concatenate([p[m & obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
-     centroid = np.mean(pts_obj, axis=0)  # obj center
-     obj2world = np.eye(4)
-     obj2world[:3, 3] = -centroid  # T_wc
-
-     # get z_up vector
-     # TODO fit a plane and get the normal vector
-     pcd = o3d.geometry.PointCloud()
-     pcd.points = o3d.utility.Vector3dVector(pts)
-     plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
-     # get the normalised normal vector, dim = 3
-     normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
-     # the normal direction should be pointing up
-     if normal[1] < 0:
-         normal = -normal
-     # print("normal", normal)
-
-     # # TODO z-up 180
-     # z_up = np.array([[1, 0, 0, 0],
-     #                  [0, -1, 0, 0],
-     #                  [0, 0, -1, 0],
-     #                  [0, 0, 0, 1]])
-     # obj2world = z_up @ obj2world
-
-     # # avg the y
-     # z_up_avg = cams2world[:, :3, 3].sum(0) / np.linalg.norm(cams2world[:, :3, 3].sum(0), axis=-1)  # average direction in cam coordinate
-     # # import pdb; pdb.set_trace()
-     # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
-     # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
-     # rot = Rotation.from_rotvec(rot_angle * rot_axis)
-     # z_up = np.eye(4)
-     # z_up[:3, :3] = rot.as_matrix()
-
-     # get the rotation matrix from normal to z-axis
-     z_axis = np.array([0, 0, 1])
-     rot_axis = np.cross(normal, z_axis)
-     rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
-     rot = Rotation.from_rotvec(rot_angle * rot_axis)
-     z_up = np.eye(4)
-     z_up[:3, :3] = rot.as_matrix()
-     obj2world = z_up @ obj2world
-     # flip 180
-     flip_rot = np.array([[1, 0, 0, 0],
-                          [0, -1, 0, 0],
-                          [0, 0, -1, 0],
-                          [0, 0, 0, 1]])
-     obj2world = flip_rot @ obj2world
-
-     # get new cams2obj
-     cams2obj = []
-     for i, cam2world in enumerate(cams2world):
-         cams2obj.append(obj2world @ cam2world)
-     # TODO transform pts3d to the new coordinate system
-     for i, pts in enumerate(pts3d):
-         pts_h = np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1)  # homogeneous coordinates (H, W, 4)
-         pts3d[i] = (obj2world @ pts_h.transpose(2, 0, 1).reshape(4, -1)) \
-             .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
-     cams2world = np.array(cams2obj)
-     # TODO rewrite hack
-     scene.vis_poses = cams2world.copy()
-     scene.vis_pts3d = pts3d.copy()
-
-     # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
-     for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
-         np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
-         pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
-         pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
-         # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
-     # save the per-camera radius (distance of each camera from the origin)
-     radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3], axis=-1)
-     np.save(os.path.join(outdir, "radii.npy"), radii)
-
-     eschernet_input = {"poses": cams2world,
-                        "radii": radii,
-                        "imgs": rgbaimg}
-
-     outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
-                                       clean_depth, transparent_cams, cam_size, same_focals=same_focals)
-
-     return scene, outfile, imgs, eschernet_input
-
-
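- # adjust visibility and ranges of the scene-graph sliders based on the selected graph type and number of uploaded files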
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
-     num_files = len(inputfiles) if inputfiles is not None else 1
-     max_winsize = max(1, math.ceil((num_files - 1) / 2))
-     if scenegraph_type == "swin":
-         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
-                             minimum=1, maximum=max_winsize, step=1, visible=True)
-         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
-                           maximum=num_files - 1, step=1, visible=False)
-     elif scenegraph_type == "oneref":
-         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
-                             minimum=1, maximum=max_winsize, step=1, visible=False)
-         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
-                           maximum=num_files - 1, step=1, visible=True)
-     else:
-         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
-                             minimum=1, maximum=max_winsize, step=1, visible=False)
-         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
-                           maximum=num_files - 1, step=1, visible=False)
-     return winsize, refid
-
-
- def main():
-     # dust3r init
-     silent = False
-     image_size = 224
-     weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
-     model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
-     # dust3r will write the 3D model inside tmpdirname
-     # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
-     tmpdirname = os.path.join('logs/user_object')
-     # remove the directory if it already exists
-     if os.path.exists(tmpdirname):
-         shutil.rmtree(tmpdirname)
-     os.makedirs(tmpdirname, exist_ok=True)
-     if not silent:
-         print('Outputting stuff in', tmpdirname)
-
-     recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
-     model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
-
-     generate_mvs = functools.partial(run_eschernet, tmpdirname)
-
-     _HEADER_ = '''
-     '''
-
-     _CITE_ = r"""
-     """
-
-     with gr.Blocks() as demo:
-         gr.Markdown(_HEADER_)
-         mv_images = gr.State()
-         scene = gr.State(None)
-         eschernet_input = gr.State(None)
-         with gr.Row(variant="panel"):
-             # left column
-             with gr.Column():
-                 with gr.Row():
-                     input_image = gr.File(file_count="multiple")
-                 with gr.Row():
-                     run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
-                 with gr.Row():
-                     processed_image = gr.Gallery(label='rgb,rgba', columns=2, height="100%")
-
-                 # with gr.Row(variant="panel"):
-                 #     gr.Examples(
-                 #         examples=[
-                 #             os.path.join("examples/hairdryer", img_name) for img_name in sorted(os.listdir("examples/hairdryer"))
-                 #         ],
-                 #         inputs=[input_image],
-                 #         label="Examples",
-                 #         examples_per_page=20
-                 #     )
-
-             # right column
-             with gr.Column():
-
-                 with gr.Row():
-                     outmodel = gr.Model3D()
-
-                 with gr.Row():
-                     gr.Markdown('''Check if the pose and segmentation look correct. If not, remove the incorrect images and try again.''')
-
-                 with gr.Row():
-                     with gr.Group():
-                         do_remove_background = gr.Checkbox(
-                             label="Remove Background", value=True
-                         )
-                         sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
-
-                         sample_steps = gr.Slider(
-                             label="Sample Steps",
-                             minimum=30,
-                             maximum=75,
-                             value=50,
-                             step=5,
-                             visible=False
-                         )
-
-                         nvs_num = gr.Slider(
-                             label="Number of Novel Views",
-                             minimum=5,
-                             maximum=100,
-                             value=30,
-                             step=1
-                         )
-
-                         nvs_mode = gr.Dropdown(["archimedes circle", "fixed 4 views", "fixed 8 views"],
-                                                value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
-
-                 with gr.Row():
-                     gr.Markdown('''Choose your desired novel view poses and generate!''')
-
-                 with gr.Row():
-                     submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
-
-                 with gr.Row():
-                     # mv_show_images = gr.Image(
-                     #     label="Generated Multi-views",
-                     #     type="pil",
-                     #     width=379,
-                     #     interactive=False
-                     # )
-                     with gr.Column():
-                         output_video = gr.Video(
-                             label="video", format="mp4",
-                             width=379,
-                             autoplay=True,
-                             interactive=False
-                         )
-
-                 # with gr.Row():
-                 #     with gr.Tab("OBJ"):
-                 #         output_model_obj = gr.Model3D(
-                 #             label="Output Model (OBJ Format)",
-                 #             # width=768,
-                 #             interactive=False,
-                 #         )
-                 #         gr.Markdown("Note: Downloaded .obj model will be flipped. Export .glb instead or manually flip it before usage.")
-                 #     with gr.Tab("GLB"):
-                 #         output_model_glb = gr.Model3D(
-                 #             label="Output Model (GLB Format)",
-                 #             # width=768,
-                 #             interactive=False,
-                 #         )
-                 #         gr.Markdown("Note: The model shown here has a darker appearance. Download to get correct results.")
-
-                 with gr.Row():
-                     gr.Markdown('''Try a different <b>seed value</b> if the result is unsatisfying (Default: 42).''')
-
-         gr.Markdown(_CITE_)
-
-         # keep the dust3r parameters invisible to keep the UI clean
-         with gr.Column():
-             with gr.Row():
-                 schedule = gr.Dropdown(["linear", "cosine"],
-                                        value='linear', label="schedule", info="For global alignment!", visible=False)
-                 niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
-                                   label="num_iterations", info="For global alignment!", visible=False)
-                 scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
-                                               value='complete', label="Scenegraph",
-                                               info="Define how to make pairs",
-                                               interactive=True, visible=False)
-                 same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
-                 winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
-                                     minimum=1, maximum=1, step=1, visible=False)
-                 refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
-
-             with gr.Row():
-                 # adjust the confidence threshold
-                 min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
-                 # adjust the camera size in the output pointcloud
-                 cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
-             with gr.Row():
-                 as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
-                 # two post-processing options
-                 mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
-                 clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
-                 transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
-
-         # events
-         # scenegraph_type.change(set_scenegraph_options,
-         #                        inputs=[input_image, winsize, refid, scenegraph_type],
-         #                        outputs=[winsize, refid])
-         input_image.change(set_scenegraph_options,
-                            inputs=[input_image, winsize, refid, scenegraph_type],
-                            outputs=[winsize, refid])
-         # min_conf_thr.release(fn=model_from_scene_fun,
-         #                      inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                              clean_depth, transparent_cams, cam_size, same_focals],
-         #                      outputs=outmodel)
-         # cam_size.change(fn=model_from_scene_fun,
-         #                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                         clean_depth, transparent_cams, cam_size, same_focals],
-         #                 outputs=outmodel)
-         # as_pointcloud.change(fn=model_from_scene_fun,
-         #                      inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                              clean_depth, transparent_cams, cam_size, same_focals],
-         #                      outputs=outmodel)
-         # mask_sky.change(fn=model_from_scene_fun,
-         #                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                         clean_depth, transparent_cams, cam_size, same_focals],
-         #                 outputs=outmodel)
-         # clean_depth.change(fn=model_from_scene_fun,
-         #                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                            clean_depth, transparent_cams, cam_size, same_focals],
-         #                    outputs=outmodel)
-         # transparent_cams.change(model_from_scene_fun,
-         #                         inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
-         #                                 clean_depth, transparent_cams, cam_size, same_focals],
-         #                         outputs=outmodel)
-         run_dust3r.click(fn=recon_fun,
-                          inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
-                                  mask_sky, clean_depth, transparent_cams, cam_size,
-                                  scenegraph_type, winsize, refid, same_focals],
-                          outputs=[scene, outmodel, processed_image, eschernet_input])
-
-         submit.click(fn=generate_mvs,
-                      inputs=[eschernet_input, sample_steps, sample_seed,
-                              nvs_num, nvs_mode],
-                      outputs=[mv_images, output_video],
-                      )  # .success(
-         #                  fn=make3d,
-         #                  inputs=[mv_images],
-         #                  outputs=[output_video, output_model_obj, output_model_glb]
-         #                  )
-
-     demo.queue(max_size=10)
-     demo.launch(share=True, server_name="0.0.0.0", server_port=None)
-
-
- if __name__ == '__main__':
-     main()