Spaces: Running on Zero
刘虹雨 committed · Commit af31c35 · 1 Parent(s): 42d9724
update code
app.py CHANGED
@@ -56,19 +56,7 @@ else:
 
 import torch
 
-
-torch.cuda.set_device(0)
-
-# Explicitly initialize CUDA (usually optional, but helps avoid issues in multi-threaded code)
-torch.cuda.init()
-
-# Test
-print("CUDA available:", torch.cuda.is_available())
-print("Current device:", torch.cuda.current_device())
-print("Device name:", torch.cuda.get_device_name(0))
-print("CUDA_HOME =", os.environ.get("CUDA_HOME"))
-from torch.utils.cpp_extension import CUDA_HOME
-print("CUDA_HOME from PyTorch:", CUDA_HOME)
+
 import argparse
 import json
 import random
@@ -264,6 +252,7 @@ def generate_samples(DiT_model, cfg_scale, sample_steps, clip_feature, dino_feat
 
 
 def load_motion_aware_render_model(ckpt_path, device):
+
     """Load the motion-aware render model from a checkpoint."""
     logging.info("Loading motion-aware render model...")
     with dnnlib.util.open_url(ckpt_path, 'rb') as f:
@@ -407,13 +396,14 @@ def images_to_video(image_folder, output_video, fps=30):
 
     print(f"✅ High-quality MP4 video has been generated: {output_video}")
 
-
+@spaces.GPU(duration=100)
 def model_define():
     args = get_args()
     set_env(args.seed)
     input_process_model = Process(cfg)
 
-    device = "cuda"
+    device = "cuda"
+
     weight_dtype = torch.float32
     logging.info(f"Running inference with {weight_dtype}")
 
@@ -450,8 +440,18 @@ def model_define():
     base_coff = torch.from_numpy(base_coff).float()
     Faceverse = Faceverse_manager(device=device, base_coeff=base_coff)
 
+    controlnet_path = './pretrained_model/control'
+    controlnet = ControlNetModel.from_pretrained(
+        controlnet_path, torch_dtype=torch.float16
+    )
+    sd_path = './pretrained_model/sd21'
+    pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+        sd_path, torch_dtype=torch.float16,
+        use_safetensors=True, controlnet=controlnet, variant="fp16"
+    ).to(device)
+
     return motion_aware_render_model, sample_steps, DiT_model, \
-        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model
+        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, triplane_std, triplane_mean, ws_avg, Faceverse, device, input_process_model, pipeline_sd
 
 
 def duplicate_batch(tensor, batch_size=2):
@@ -463,6 +463,8 @@ def duplicate_batch(tensor, batch_size=2):
 @torch.inference_mode()
 @spaces.GPU(duration=200)
 def avatar_generation(items, save_path_base, video_path_input, source_type, is_styled, styled_img):
+
+
     """
     Generate avatars from input images.
 
@@ -650,7 +652,7 @@ def process_image(input_image, source_type, is_style, save_dir):
     imge_dir = os.path.join(save_dir, 'processed_img/dataset/images512x512/input_image', img_name)
     return imge_dir, source_type  # Replace this with the logic for processing user-uploaded images
 
-
+@spaces.GPU(duration=100)
 def style_transfer(processed_image, style_prompt, cfg, strength, save_base):
     """
     🎭 This function performs style transfer
@@ -1001,15 +1003,7 @@ if __name__ == '__main__':
     image_folder = "./demo_data/source_img/img_generate_different_domain/images512x512/demo_imgs"
     example_img_names = os.listdir(image_folder)
     render_model, sample_steps, DiT_model, \
-        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model = model_define()
-
-    controlnet = ControlNetModel.from_pretrained(
-        controlnet_path, torch_dtype=torch.float16
-    )
-    sd_path = './pretrained_model/sd21'
-    pipeline_sd = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
-        sd_path, torch_dtype=torch.float16,
-        use_safetensors=True, controlnet=controlnet, variant="fp16"
-    ).to(device)
+        vae_triplane, image_encoder, dinov2, dino_img_processor, clip_image_processor, std, mean, ws_avg, Faceverse, device, input_process_model, pipeline_sd = model_define()
+
     demo_cam = False
     launch_gradio_app()
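Why the change works: on ZeroGPU Spaces ("Running on Zero"), no GPU is attached at import time, so the module-level torch.cuda.set_device/torch.cuda.init block had to go; CUDA may only be touched inside functions wrapped with @spaces.GPU, which attaches a GPU for at most `duration` seconds per call. That is also why the ControlNet/Stable Diffusion pipeline construction moves from __main__ into model_define(). A minimal sketch of the idiom, assuming only the spaces package that ZeroGPU provides (gpu_task and its argument are illustrative, not part of app.py):

import spaces
import torch

@spaces.GPU(duration=100)  # hold a GPU for at most 100 s while this call runs
def gpu_task(x: torch.Tensor) -> torch.Tensor:
    # On ZeroGPU, torch.cuda only becomes usable inside the decorated call.
    return (x.to("cuda") * 2).cpu()

The pipeline_sd that model_define() now returns is a stock diffusers StableDiffusionControlNetImg2ImgPipeline, so a call from style_transfer() could plausibly look like the fragment below; init_image, control_image, and the step count are illustrative assumptions, while style_prompt, cfg, and strength come from the function's signature:

result = pipeline_sd(
    prompt=style_prompt,          # text description of the target style
    image=init_image,             # PIL image to restyle (hypothetical name)
    control_image=control_image,  # ControlNet conditioning image (hypothetical name)
    strength=strength,            # 0-1: how far to move away from the input image
    guidance_scale=cfg,           # classifier-free guidance scale
    num_inference_steps=30,       # assumed step count
).images[0]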