kxhit committed on
Commit
329487d
·
1 Parent(s): 20a47d0

rm partial

Browse files
Files changed (1) hide show
  1. app.py +31 -16
app.py CHANGED
@@ -116,7 +116,7 @@ pipeline.enable_vae_slicing()
116
 
117
 
118
  @spaces.GPU(duration=120)
119
- def run_eschernet(tmpdirname, eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
120
  # set the random seed
121
  generator = torch.Generator(device=device).manual_seed(sample_seed)
122
  T_out = nvs_num
@@ -289,7 +289,7 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
289
  outfile = os.path.join(outdir, 'scene.glb')
290
  if not silent:
291
  print('(exporting 3D scene to', outfile, ')')
292
- # scene.export(file_obj=outfile)
293
  return outfile
294
 
295
  @spaces.GPU(duration=120)
@@ -325,16 +325,19 @@ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud
325
  same_focals=same_focals)
326
 
327
  @spaces.GPU(duration=120)
328
- def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, schedule, niter, min_conf_thr,
329
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
330
  scenegraph_type, winsize, refid, same_focals):
331
  """
332
  from a list of images, run dust3r inference, global aligner.
333
  then run get_3D_model_from_scene
334
  """
 
 
335
  weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
336
  model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
337
  # remove the directory if it already exists
 
338
  if os.path.exists(outdir):
339
  shutil.rmtree(outdir)
340
  os.makedirs(outdir, exist_ok=True)
@@ -541,10 +544,10 @@ os.makedirs(tmpdirname, exist_ok=True)
541
  if not silent:
542
  print('Outputing stuff in', tmpdirname)
543
 
544
- recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
545
- model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
546
 
547
- generate_mvs = functools.partial(run_eschernet, tmpdirname)
548
 
549
  _HEADER_ = '''
550
  <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
@@ -755,11 +758,17 @@ with gr.Blocks() as demo:
755
  # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
756
  # clean_depth, transparent_cams, cam_size, same_focals],
757
  # outputs=outmodel)
758
- run_dust3r.click(fn=recon_fun,
759
- inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
760
- mask_sky, clean_depth, transparent_cams, cam_size,
761
- scenegraph_type, winsize, refid, same_focals],
762
- outputs=[scene, outmodel, processed_image, eschernet_input])
 
 
 
 
 
 
763
 
764
 
765
  # events
@@ -768,16 +777,22 @@ with gr.Blocks() as demo:
768
  inputs=[input_image],
769
  outputs=[processed_image])
770
 
771
- submit.click(fn=generate_mvs,
772
- inputs=[eschernet_input, sample_steps, sample_seed,
773
- nvs_num, nvs_mode],
774
- outputs=[mv_images, output_video],
775
- )#.success(
776
  # # fn=make3d,
777
  # # inputs=[mv_images],
778
  # # outputs=[output_video, output_model_obj, output_model_glb]
779
  # # )
780
 
 
 
 
 
 
 
781
 
782
 
783
  # demo.queue(max_size=10)
 
116
 
117
 
118
  @spaces.GPU(duration=120)
119
+ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
120
  # set the random seed
121
  generator = torch.Generator(device=device).manual_seed(sample_seed)
122
  T_out = nvs_num
 
289
  outfile = os.path.join(outdir, 'scene.glb')
290
  if not silent:
291
  print('(exporting 3D scene to', outfile, ')')
292
+ scene.export(file_obj=outfile)
293
  return outfile
294
 
295
  @spaces.GPU(duration=120)
 
325
  same_focals=same_focals)
326
 
327
  @spaces.GPU(duration=120)
328
+ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
329
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
330
  scenegraph_type, winsize, refid, same_focals):
331
  """
332
  from a list of images, run dust3r inference, global aligner.
333
  then run get_3D_model_from_scene
334
  """
335
+ silent = False
336
+ image_size = 224
337
  weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
338
  model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
339
  # remove the directory if it already exists
340
+ outdir = tmpdirname
341
  if os.path.exists(outdir):
342
  shutil.rmtree(outdir)
343
  os.makedirs(outdir, exist_ok=True)
 
544
  if not silent:
545
  print('Outputing stuff in', tmpdirname)
546
 
547
+ # recon_fun = functools.partial(get_reconstructed_scene, tmpdirname, model, device, silent, image_size)
548
+ # model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname, silent)
549
 
550
+ # generate_mvs = functools.partial(run_eschernet, tmpdirname)
551
 
552
  _HEADER_ = '''
553
  <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
 
758
  # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
759
  # clean_depth, transparent_cams, cam_size, same_focals],
760
  # outputs=outmodel)
761
+ # run_dust3r.click(fn=recon_fun,
762
+ # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
763
+ # mask_sky, clean_depth, transparent_cams, cam_size,
764
+ # scenegraph_type, winsize, refid, same_focals],
765
+ # outputs=[scene, outmodel, processed_image, eschernet_input])
766
+
767
+ run_dust3r.click(fn=get_reconstructed_scene,
768
+ inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
769
+ mask_sky, clean_depth, transparent_cams, cam_size,
770
+ scenegraph_type, winsize, refid, same_focals],
771
+ outputs=[scene, outmodel, processed_image, eschernet_input],)
772
 
773
 
774
  # events
 
777
  inputs=[input_image],
778
  outputs=[processed_image])
779
 
780
+ # submit.click(fn=generate_mvs,
781
+ # inputs=[eschernet_input, sample_steps, sample_seed,
782
+ # nvs_num, nvs_mode],
783
+ # outputs=[mv_images, output_video],
784
+ # )#.success(
785
  # # fn=make3d,
786
  # # inputs=[mv_images],
787
  # # outputs=[output_video, output_model_obj, output_model_glb]
788
  # # )
789
 
790
+ submit.click(fn=run_eschernet,
791
+ inputs=[eschernet_input, sample_steps, sample_seed,
792
+ nvs_num, nvs_mode],
793
+ outputs=[mv_images, output_video],
794
+ )
795
+
796
 
797
 
798
  # demo.queue(max_size=10)