刘虹雨 committed on
Commit
5834ebe
·
1 Parent(s): af31c35

update code

Browse files
Files changed (1) hide show
  1. app.py +23 -19
app.py CHANGED
@@ -396,13 +396,13 @@ def images_to_video(image_folder, output_video, fps=30):
396
 
397
  print(f"✅ High-quality MP4 video has been generated: {output_video}")
398
 
399
- @spaces.GPU(duration=100)
400
  def model_define():
401
  args = get_args()
402
  set_env(args.seed)
403
  input_process_model = Process(cfg)
404
 
405
- device = "cuda"
406
 
407
  weight_dtype = torch.float32
408
  logging.info(f"Running inference with {weight_dtype}")
@@ -440,18 +440,8 @@ def model_define():
440
  base_coff = torch.from_numpy(base_coff).float()
441
  Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
442
 
443
- controlnet_path = './pretrained_model/control'
444
- controlnet = ControlNetModel.from_pretrained(
445
- controlnet_path, torch_dtype=torch.float16
446
- )
447
- sd_path = './pretrained_model/sd21'
448
- pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
449
- sd_path, torch_dtype=torch.float16,
450
- use_safetensors=True, controlnet=controlnet, variant="fp16"
451
- ).to(device)
452
-
453
  return motion_aware_render_model, sample_steps, DiT_model, \
454
- vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model,pipeline_sd
455
 
456
 
457
  def duplicate_batch(tensor, batch_size=2):
@@ -460,11 +450,8 @@ def duplicate_batch(tensor, batch_size=2):
460
  return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1))) # 复制 batch 维度
461
 
462
 
463
- @torch.inference_mode()
464
  @spaces.GPU(duration=200)
465
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
466
-
467
-
468
  """
469
  Generate avatars from input images.
470
 
@@ -491,7 +478,15 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
491
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
492
  motion_base_dir = os.path.join(target_path, 'motions')
493
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
494
-
 
 
 
 
 
 
 
 
495
  if source_type == 'example':
496
  input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
497
  input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
@@ -658,6 +653,7 @@ def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
658
  🎭 这个函数用于风格转换
659
  ✅ 你可以在这里填入你的风格化代码
660
  """
 
661
  src_img_pil = Image.open(processed_image)
662
  img_name = os.path.basename(processed_image)
663
  save_dir = os.path.join(save_base, 'style_img')
@@ -1003,7 +999,15 @@ if __name__ == '__main__':
1003
  image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
1004
  example_img_names = os.listdir(image_folder)
1005
  render_model, sample_steps, DiT_model, \
1006
- vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model, pipeline_sd = model_define()
1007
-
 
 
 
 
 
 
 
 
1008
  demo_cam = False
1009
  launch_gradio_app()
 
396
 
397
  print(f"✅ High-quality MP4 video has been generated: {output_video}")
398
 
399
+
400
  def model_define():
401
  args = get_args()
402
  set_env(args.seed)
403
  input_process_model = Process(cfg)
404
 
405
+ device = "cuda" if torch.cuda.is_available() else "cpu"
406
 
407
  weight_dtype = torch.float32
408
  logging.info(f"Running inference with {weight_dtype}")
 
440
  base_coff = torch.from_numpy(base_coff).float()
441
  Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
442
 
 
 
 
 
 
 
 
 
 
 
443
  return motion_aware_render_model, sample_steps, DiT_model, \
444
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model
445
 
446
 
447
  def duplicate_batch(tensor, batch_size=2):
 
450
  return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1))) # 复制 batch 维度
451
 
452
 
 
453
  @spaces.GPU(duration=200)
454
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
 
 
455
  """
456
  Generate avatars from input images.
457
 
 
478
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
479
  motion_base_dir = os.path.join(target_path, 'motions')
480
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
481
+ render_model =render_model.to(device)
482
+ image_encoder = image_encoder.to(device)
483
+ vae_triplane = vae_triplane.to(device)
484
+ dinov2 = dinov2.to(device)
485
+ Faceverse = Faceverse.to(device)
486
+ clip_image_processor = clip_image_processor.to(device)
487
+ dino_img_processor = dino_img_processor.to(device)
488
+ ws_avg = ws_avg.to(device)
489
+ DiT_model = DiT_model.to(device)
490
  if source_type == 'example':
491
  input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
492
  input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
 
653
  🎭 这个函数用于风格转换
654
  ✅ 你可以在这里填入你的风格化代码
655
  """
656
+ pipeline_sd =pipeline_sd.to(device)
657
  src_img_pil = Image.open(processed_image)
658
  img_name = os.path.basename(processed_image)
659
  save_dir = os.path.join(save_base, 'style_img')
 
999
  image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
1000
  example_img_names = os.listdir(image_folder)
1001
  render_model, sample_steps, DiT_model, \
1002
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
1003
+ controlnet_path = './pretrained_model/control'
1004
+ controlnet = ControlNetModel.from_pretrained(
1005
+ controlnet_path, torch_dtype=torch.float16
1006
+ )
1007
+ sd_path = './pretrained_model/sd21'
1008
+ pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
1009
+ sd_path, torch_dtype=torch.float16,
1010
+ use_safetensors=True, controlnet=controlnet, variant="fp16"
1011
+ ).to(device)
1012
  demo_cam = False
1013
  launch_gradio_app()