刘虹雨 commited on
Commit
6dce7fe
·
1 Parent(s): 10ae7b2

update code

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -473,12 +473,12 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
473
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
474
  motion_base_dir = os.path.join(target_path, 'motions')
475
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
476
- render_model.to(device)
477
- image_encoder.to(device)
478
- vae_triplane.to(device)
479
- dinov2.to(device)
480
- ws_avg.to(device)
481
- DiT_model.to(device)
482
  # Set up face verse for amimation
483
  base_coff = np.load(
484
  'pretrained_model/temp.npy').astype(
@@ -526,7 +526,7 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
526
  samples = rearrange(samples, "b c f h w -> b f c h w")
527
  samples = samples * std + mean
528
  torch.cuda.empty_cache()
529
-
530
  save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
531
  save_frames_path_outshow = os.path.join(save_path_base, image_name, 'out_show')
532
  save_frames_path_depth = os.path.join(save_path_base, image_name, 'depth')
@@ -580,22 +580,22 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
580
  # Load motion data
581
  motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
582
 
583
- img_ref_double = duplicate_batch(img_ref, batch_size=2)
584
- motion_app_double = duplicate_batch(motion_app, batch_size=2)
585
- motion_double = duplicate_batch(motion, batch_size=2)
586
- pose_double = torch.cat([pose_show, pose], dim=0)
587
- exp_target_double = duplicate_batch(exp_target, batch_size=2)
588
- samples_double = duplicate_batch(samples, batch_size=2)
589
  # Select refine_net processing method
590
  final_out = render_model(
591
- img_ref_double, None, motion_app_double, motion_double, c=pose_double, mesh=exp_target_double,
592
- triplane_recon=samples_double,
593
  ws_avg=ws_avg, motion_scale=1.
594
  )
595
 
596
  # Process output image
597
  final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
598
- final_out_notshow = trans(final_out['image_sr'][1].unsqueeze(0))
599
  depth = final_out['image_depth'][0].unsqueeze(0)
600
  depth = -depth
601
  depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1
 
473
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
474
  motion_base_dir = os.path.join(target_path, 'motions')
475
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
476
+ # render_model.to(device)
477
+ # image_encoder.to(device)
478
+ # vae_triplane.to(device)
479
+ # dinov2.to(device)
480
+ # ws_avg.to(device)
481
+ # DiT_model.to(device)
482
  # Set up face verse for amimation
483
  base_coff = np.load(
484
  'pretrained_model/temp.npy').astype(
 
526
  samples = rearrange(samples, "b c f h w -> b f c h w")
527
  samples = samples * std + mean
528
  torch.cuda.empty_cache()
529
+ torch.cuda.ipc_collect()
530
  save_frames_path_out = os.path.join(save_path_base, image_name, 'out')
531
  save_frames_path_outshow = os.path.join(save_path_base, image_name, 'out_show')
532
  save_frames_path_depth = os.path.join(save_path_base, image_name, 'depth')
 
580
  # Load motion data
581
  motion = torch.tensor(np.load(motion_each_dir)).float().unsqueeze(0).to(device)
582
 
583
+ # img_ref_double = duplicate_batch(img_ref, batch_size=2)
584
+ # motion_app_double = duplicate_batch(motion_app, batch_size=2)
585
+ # motion_double = duplicate_batch(motion, batch_size=2)
586
+ # pose_double = torch.cat([pose_show, pose], dim=0)
587
+ # exp_target_double = duplicate_batch(exp_target, batch_size=2)
588
+ # samples_double = duplicate_batch(samples, batch_size=2)
589
  # Select refine_net processing method
590
  final_out = render_model(
591
+ img_ref, None, motion_app, motion, c=pose, mesh=exp_target,
592
+ triplane_recon=samples,
593
  ws_avg=ws_avg, motion_scale=1.
594
  )
595
 
596
  # Process output image
597
  final_out_show = trans(final_out['image_sr'][0].unsqueeze(0))
598
+ final_out_notshow = trans(final_out['image_sr'][0].unsqueeze(0))
599
  depth = final_out['image_depth'][0].unsqueeze(0)
600
  depth = -depth
601
  depth = (depth - depth.min()) / (depth.max() - depth.min()) * 2 - 1