Spaces: Running on Zero
刘虹雨 committed · Commit 5834ebe
Parent: af31c35

update code

app.py CHANGED
@@ -396,13 +396,13 @@ def images_to_video(image_folder, output_video, fps=30):
 
     print(f"✅ High-quality MP4 video has been generated: {output_video}")
 
 
 def model_define():
     args = get_args()
     set_env(args.seed)
     input_process_model = Process(cfg)
 
-    device = "cuda"
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     weight_dtype = torch.float32
     logging.info(f"Running inference with {weight_dtype}")
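This device fallback is the core ZeroGPU fix: while the Space boots, no GPU is attached, so hard-coding `"cuda"` fails at import time. A minimal sketch of the load-on-CPU, migrate-later pattern the commit moves toward (the `load_checkpoint` helper is hypothetical, not from app.py):

```python
import torch

# No CUDA device is attached while a ZeroGPU Space boots, so resolve the
# device with a CPU fallback instead of hard-coding "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_checkpoint(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # Hypothetical helper: load weights onto whatever device exists now; the
    # @spaces.GPU-decorated handlers move modules to CUDA per request.
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    return model.to(device).eval()
```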
@@ -440,18 +440,8 @@ def model_define():
     base_coff = torch.from_numpy(base_coff).float()
     Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
 
-    controlnet_path = './pretrained_model/control'
-    controlnet = ControlNetModel.from_pretrained(
-        controlnet_path, torch_dtype=torch.float16
-    )
-    sd_path = './pretrained_model/sd21'
-    pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-        sd_path, torch_dtype=torch.float16,
-        use_safetensors=True, controlnet=controlnet, variant="fp16"
-    ).to(device)
-
     return motion_aware_render_model, sample_steps, DiT_model, \
         vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model
 
 
 def duplicate_batch(tensor, batch_size=2):
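Note that the ControlNet/Stable Diffusion pipeline deleted here is not gone: the same construction reappears in the `__main__` block (last hunk below). In the old code `pipeline_sd` was a local of `model_define()` and never returned; building it at module scope instead makes it reachable as a global from `style_transfer`.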
@@ -460,11 +450,8 @@ def duplicate_batch(tensor, batch_size=2):
     return tensor.repeat(batch_size, *([1] * (tensor.dim() - 1)))  # duplicate the batch dimension
 
 
-@torch.inference_mode()
 @spaces.GPU(duration=200)
 def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
-
-
     """
     Generate avatars from input images.
 
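With `@torch.inference_mode()` removed, `@spaces.GPU` is the only decorator ZeroGPU sees on the handler. If no-grad semantics are still wanted, one option (a sketch, not necessarily what app.py does) is to scope inference mode inside the body:

```python
import torch
import spaces

@spaces.GPU(duration=200)
def avatar_generation(items, save_path_base, video_path_input,
                      source_type, is_styled, styled_img):
    # Scoping inference mode in the body keeps the decorator stack down to
    # the single @spaces.GPU wrapper while still disabling autograd.
    with torch.inference_mode():
        ...  # generation steps
```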
@@ -491,7 +478,15 @@ def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
     exp_img_base_dir = os.path.join(target_path, 'images512x512')
     motion_base_dir = os.path.join(target_path, 'motions')
     label_file_test = os.path.join(target_path, 'images512x512/dataset_realcam.json')
-
+    render_model = render_model.to(device)
+    image_encoder = image_encoder.to(device)
+    vae_triplane = vae_triplane.to(device)
+    dinov2 = dinov2.to(device)
+    Faceverse = Faceverse.to(device)
+    clip_image_processor = clip_image_processor.to(device)
+    dino_img_processor = dino_img_processor.to(device)
+    ws_avg = ws_avg.to(device)
+    DiT_model = DiT_model.to(device)
     if source_type == 'example':
         input_img_fvid = './demo_data/source_img/img_generate_different_domain/coeffs/demo_imgs'
         input_img_motion = './demo_data/source_img/img_generate_different_domain/motions/demo_imgs'
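This block is the other half of the ZeroGPU pattern: everything is built on CPU in `model_define()` and migrated once a GPU is attached for this call. One caveat worth checking: if `clip_image_processor` and `dino_img_processor` are transformers-style image processors rather than `nn.Module`s, they have no `.to()` method and those two lines would raise `AttributeError`. A condensed sketch of the migration step, with stand-in modules:

```python
import torch
import spaces

# Stand-ins for the real models that model_define() returns on CPU.
modules = {
    "render_model": torch.nn.Linear(4, 4),
    "ws_avg": torch.zeros(1, 512),
}

@spaces.GPU(duration=200)
def handler(*args):
    # A GPU is attached only for the duration of this call, so migrate the
    # CPU-resident modules here rather than at import time.
    for name, m in modules.items():
        modules[name] = m.to("cuda")
```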
@@ -658,6 +653,7 @@ def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
     🎭 This function performs style transfer.
     ✅ Plug your stylization code in here.
     """
+    pipeline_sd = pipeline_sd.to(device)
     src_img_pil = Image.open(processed_image)
     img_name = os.path.basename(processed_image)
     save_dir = os.path.join(save_base, 'style_img')
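Two notes on this hunk. First, `pipeline_sd = pipeline_sd.to(device)` rebinds the name inside the function, making it local to `style_transfer`; unless the function declares `global pipeline_sd` (not visible in this diff), Python raises `UnboundLocalError` before `.to()` even runs. Second, for reference, a typical diffusers ControlNet img2img call looks like the sketch below; the exact arguments app.py passes are outside this diff:

```python
from PIL import Image

def run_style_transfer(pipeline_sd, src_img_pil: Image.Image, style_prompt: str,
                       control_img: Image.Image, cfg: float, strength: float) -> Image.Image:
    # Standard StableDiffusionControlNetImg2ImgPipeline invocation; the names
    # mirror the diff, the argument wiring is illustrative.
    result = pipeline_sd(
        prompt=style_prompt,
        image=src_img_pil,          # init image for img2img
        control_image=control_img,  # ControlNet conditioning image
        guidance_scale=cfg,
        strength=strength,
    )
    return result.images[0]
```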
@@ -1003,7 +999,15 @@ if __name__ == '__main__':
     image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
     example_img_names = os.listdir(image_folder)
     render_model, sample_steps, DiT_model, \
-        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model
-
+        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
+    controlnet_path = './pretrained_model/control'
+    controlnet = ControlNetModel.from_pretrained(
+        controlnet_path, torch_dtype=torch.float16
+    )
+    sd_path = './pretrained_model/sd21'
+    pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+        sd_path, torch_dtype=torch.float16,
+        use_safetensors=True, controlnet=controlnet, variant="fp16"
+    ).to(device)
     demo_cam = False
     launch_gradio_app()
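This is where the pipeline construction removed from `model_define()` lands: `controlnet` and `pipeline_sd` are now built once at startup, at module scope. Since `device` comes from `model_define()`, on ZeroGPU the fp16 pipeline initially sits on CPU until `style_transfer` moves it over. If startup cost were a concern, construction could also be deferred until first use; a sketch under that assumption (the lazy wrapper is not what app.py does):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline

_pipeline_sd = None

def get_pipeline_sd(device: str) -> StableDiffusionControlNetImg2ImgPipeline:
    # Build the ControlNet img2img pipeline on first use and cache it;
    # paths, dtype, and variant mirror the commit.
    global _pipeline_sd
    if _pipeline_sd is None:
        controlnet = ControlNetModel.from_pretrained(
            './pretrained_model/control', torch_dtype=torch.float16
        )
        _pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            './pretrained_model/sd21', torch_dtype=torch.float16,
            use_safetensors=True, controlnet=controlnet, variant="fp16"
        )
    return _pipeline_sd.to(device)
```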