刘虹雨 commited on
Commit
af31c35
·
1 Parent(s): 42d9724

update code

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -56,19 +56,7 @@ else:
56
 
57
  import torch
58
 
59
- # 设置当前默认 GPU 设备(推荐在 CUDA 初始化前设置)
60
- torch.cuda.set_device(0)
61
-
62
- # 显式初始化 CUDA(通常是可选的,但在多线程中有助于避免问题)
63
- torch.cuda.init()
64
-
65
- # 测试
66
- print("CUDA available:", torch.cuda.is_available())
67
- print("Current device:", torch.cuda.current_device())
68
- print("Device name:", torch.cuda.get_device_name(0))
69
- print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
70
- from torch.utils.cpp_extension import CUDA_HOME
71
- print("CUDA_HOME from PyTorch:", CUDA_HOME)
72
  import argparse
73
  import json
74
  import random
@@ -264,6 +252,7 @@ def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feat
264
 
265
 
266
  def load_motion_aware_render_model(ckpt_path, device):
 
267
  """Load the motion-aware render model from a checkpoint."""
268
  logging.info("Loading motion-aware render model...")
269
  with dnnlib.util.open_url(ckpt_path, 'rb') as f:
@@ -407,13 +396,14 @@ def images_to_video(image_folder, output_video, fps=30):
407
 
408
  print(f"✅ High-quality MP4 video has been generated: {output_video}")
409
 
410
-
411
  def model_define():
412
  args = get_args()
413
  set_env(args.seed)
414
  input_process_model = Process(cfg)
415
 
416
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
417
  weight_dtype = torch.float32
418
  logging.info(f"Running inference with {weight_dtype}")
419
 
@@ -450,8 +440,18 @@ def model_define():
450
  base_coff = torch.from_numpy(base_coff).float()
451
  Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
452
 
 
 
 
 
 
 
 
 
 
 
453
  return motion_aware_render_model, sample_steps, DiT_model, \
454
- vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model
455
 
456
 
457
  def duplicate_batch(tensor, batch_size=2):
@@ -463,6 +463,8 @@ def duplicate_batch(tensor, batch_size=2):
463
  @torch.inference_mode()
464
  @spaces.GPU(duration=200)
465
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
 
 
466
  """
467
  Generate avatars from input images.
468
 
@@ -650,7 +652,7 @@ def process_image(input_image, source_type, is_style, save_dir):
650
  imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
651
  return imge_dir, source_type # 这里替换成 处理用户上传图片的逻辑
652
 
653
-
654
  def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
655
  """
656
  🎭 这个函数用于风格转换
@@ -1001,15 +1003,7 @@ if __name__ == '__main__':
1001
  image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
1002
  example_img_names = os.listdir(image_folder)
1003
  render_model, sample_steps, DiT_model, \
1004
- vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
1005
- controlnet_path = './pretrained_model/control'
1006
- controlnet = ControlNetModel.from_pretrained(
1007
- controlnet_path, torch_dtype=torch.float16
1008
- )
1009
- sd_path = './pretrained_model/sd21'
1010
- pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
1011
- sd_path, torch_dtype=torch.float16,
1012
- use_safetensors=True, controlnet=controlnet, variant="fp16"
1013
- ).to(device)
1014
  demo_cam = False
1015
  launch_gradio_app()
 
56
 
57
  import torch
58
 
59
+
 
 
 
 
 
 
 
 
 
 
 
 
60
  import argparse
61
  import json
62
  import random
 
252
 
253
 
254
  def load_motion_aware_render_model(ckpt_path, device):
255
+
256
  """Load the motion-aware render model from a checkpoint."""
257
  logging.info("Loading motion-aware render model...")
258
  with dnnlib.util.open_url(ckpt_path, 'rb') as f:
 
396
 
397
  print(f"✅ High-quality MP4 video has been generated: {output_video}")
398
 
399
+ @spaces.GPU(duration=100)
400
  def model_define():
401
  args = get_args()
402
  set_env(args.seed)
403
  input_process_model = Process(cfg)
404
 
405
+ device = "cuda"
406
+
407
  weight_dtype = torch.float32
408
  logging.info(f"Running inference with {weight_dtype}")
409
 
 
440
  base_coff = torch.from_numpy(base_coff).float()
441
  Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
442
 
443
+ controlnet_path = './pretrained_model/control'
444
+ controlnet = ControlNetModel.from_pretrained(
445
+ controlnet_path, torch_dtype=torch.float16
446
+ )
447
+ sd_path = './pretrained_model/sd21'
448
+ pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
449
+ sd_path, torch_dtype=torch.float16,
450
+ use_safetensors=True, controlnet=controlnet, variant="fp16"
451
+ ).to(device)
452
+
453
  return motion_aware_render_model, sample_steps, DiT_model, \
454
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model,pipeline_sd
455
 
456
 
457
  def duplicate_batch(tensor, batch_size=2):
 
463
  @torch.inference_mode()
464
  @spaces.GPU(duration=200)
465
  def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
466
+
467
+
468
  """
469
  Generate avatars from input images.
470
 
 
652
  imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
653
  return imge_dir, source_type # 这里替换成 处理用户上传图片的逻辑
654
 
655
+ @spaces.GPU(duration=100)
656
  def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
657
  """
658
  🎭 这个函数用于风格转换
 
1003
  image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
1004
  example_img_names = os.listdir(image_folder)
1005
  render_model, sample_steps, DiT_model, \
1006
+ vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model, pipeline_sd = model_define()
1007
+
 
 
 
 
 
 
 
 
1008
  demo_cam = False
1009
  launch_gradio_app()