刘虹雨 commited on
Commit
cd8a54f
·
1 Parent(s): 5834ebe

update code

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -449,7 +449,7 @@ def duplicate_batch(tensor, batch_size=2):
449
  return None # 如果是 None,则直接返回
450
  return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1))) # 复制 batch 维度
451
 
452
-
453
  @spaces.GPU(duration=200)
454
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
455
  """
@@ -478,13 +478,11 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_s
478
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
479
  motion_base_dir = os.path.join(target_path, 'motions')
480
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
481
- render_model =render_model.to(device)
482
- image_encoder = image_encoder.to(device)
483
- vae_triplane = vae_triplane.to(device)
484
- dinov2 = dinov2.to(device)
485
- Faceverse = Faceverse.to(device)
486
- clip_image_processor = clip_image_processor.to(device)
487
- dino_img_processor = dino_img_processor.to(device)
488
  ws_avg = ws_avg.to(device)
489
  DiT_model = DiT_model.to(device)
490
  if source_type == 'example':
@@ -631,6 +629,7 @@ def assert_input_image(input_image):
631
  raise gr.Error("No image selected or uploaded!")
632
 
633
  @spaces.GPU(duration=100)
 
634
  def process_image(input_image, source_type, is_style, save_dir):
635
  """ 🎯 处理 input_image,根据是否是示例图片执行不同逻辑 """
636
  process_img_input_dir = os.path.join(save_dir, 'input_image')
@@ -648,12 +647,13 @@ def process_image(input_image, source_type, is_style, save_dir):
648
  return imge_dir, source_type # 这里替换成 处理用户上传图片的逻辑
649
 
650
  @spaces.GPU(duration=100)
 
651
  def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
652
  """
653
  🎭 这个函数用于风格转换
654
  ✅ 你可以在这里填入你的风格化代码
655
  """
656
- pipeline_sd =pipeline_sd.to(device)
657
  src_img_pil = Image.open(processed_image)
658
  img_name = os.path.basename(processed_image)
659
  save_dir = os.path.join(save_base, 'style_img')
 
449
  return None # 如果是 None,则直接返回
450
  return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1))) # 复制 batch 维度
451
 
452
+ @torch.no_grad()
453
  @spaces.GPU(duration=200)
454
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
455
  """
 
478
  exp_img_base_dir = os.path.join(target_path, 'images512x512')
479
  motion_base_dir = os.path.join(target_path, 'motions')
480
  label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
481
+ render_model.to(device)
482
+ image_encoder.to(device)
483
+ vae_triplane.to(device)
484
+ dinov2.to(device)
485
+ Faceverse.to(device)
 
 
486
  ws_avg = ws_avg.to(device)
487
  DiT_model = DiT_model.to(device)
488
  if source_type == 'example':
 
629
  raise gr.Error("No image selected or uploaded!")
630
 
631
  @spaces.GPU(duration=100)
632
+ @torch.no_grad()
633
  def process_image(input_image, source_type, is_style, save_dir):
634
  """ 🎯 处理 input_image,根据是否是示例图片执行不同逻辑 """
635
  process_img_input_dir = os.path.join(save_dir, 'input_image')
 
647
  return imge_dir, source_type # 这里替换成 处理用户上传图片的逻辑
648
 
649
  @spaces.GPU(duration=100)
650
+ @torch.no_grad()
651
  def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
652
  """
653
  🎭 这个函数用于风格转换
654
  ✅ 你可以在这里填入你的风格化代码
655
  """
656
+ pipeline_sd.to(device)
657
  src_img_pil = Image.open(processed_image)
658
  img_name = os.path.basename(processed_image)
659
  save_dir = os.path.join(save_base, 'style_img')