Spaces:
Runtime error
Runtime error
Update FantasyTalking/infer.py
Browse files- FantasyTalking/infer.py +23 -112
FantasyTalking/infer.py
CHANGED
|
@@ -18,112 +18,25 @@ from FantasyTalking.utils import get_audio_features, resize_image_by_longest_edg
|
|
| 18 |
|
| 19 |
|
| 20 |
def parse_args():
|
| 21 |
-
parser = argparse.ArgumentParser(description="
|
| 22 |
-
|
| 23 |
-
parser.add_argument(
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
)
|
| 30 |
-
parser.add_argument(
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
)
|
| 37 |
-
parser.add_argument(
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
default="./models/wav2vec2-base-960h",
|
| 41 |
-
required=False,
|
| 42 |
-
help="The dir of wav2vec model.",
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
parser.add_argument(
|
| 46 |
-
"--image_path",
|
| 47 |
-
type=str,
|
| 48 |
-
default="./assets/images/woman.png",
|
| 49 |
-
required=False,
|
| 50 |
-
help="The path of the image.",
|
| 51 |
-
)
|
| 52 |
-
|
| 53 |
-
parser.add_argument(
|
| 54 |
-
"--audio_path",
|
| 55 |
-
type=str,
|
| 56 |
-
default="./assets/audios/woman.wav",
|
| 57 |
-
required=False,
|
| 58 |
-
help="The path of the audio.",
|
| 59 |
-
)
|
| 60 |
-
parser.add_argument(
|
| 61 |
-
"--prompt",
|
| 62 |
-
type=str,
|
| 63 |
-
default="A woman is talking.",
|
| 64 |
-
required=False,
|
| 65 |
-
help="prompt.",
|
| 66 |
-
)
|
| 67 |
-
parser.add_argument(
|
| 68 |
-
"--output_dir",
|
| 69 |
-
type=str,
|
| 70 |
-
default="./output",
|
| 71 |
-
help="Dir to save the model.",
|
| 72 |
-
)
|
| 73 |
-
parser.add_argument(
|
| 74 |
-
"--image_size",
|
| 75 |
-
type=int,
|
| 76 |
-
default=512,
|
| 77 |
-
help="The image will be resized proportionally to this size.",
|
| 78 |
-
)
|
| 79 |
-
parser.add_argument(
|
| 80 |
-
"--audio_scale",
|
| 81 |
-
type=float,
|
| 82 |
-
default=1.0,
|
| 83 |
-
help="Audio condition injection weight",
|
| 84 |
-
)
|
| 85 |
-
parser.add_argument(
|
| 86 |
-
"--prompt_cfg_scale",
|
| 87 |
-
type=float,
|
| 88 |
-
default=5.0,
|
| 89 |
-
required=False,
|
| 90 |
-
help="Prompt cfg scale",
|
| 91 |
-
)
|
| 92 |
-
parser.add_argument(
|
| 93 |
-
"--audio_cfg_scale",
|
| 94 |
-
type=float,
|
| 95 |
-
default=5.0,
|
| 96 |
-
required=False,
|
| 97 |
-
help="Audio cfg scale",
|
| 98 |
-
)
|
| 99 |
-
parser.add_argument(
|
| 100 |
-
"--max_num_frames",
|
| 101 |
-
type=int,
|
| 102 |
-
default=81,
|
| 103 |
-
required=False,
|
| 104 |
-
help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
|
| 105 |
-
)
|
| 106 |
-
parser.add_argument(
|
| 107 |
-
"--fps",
|
| 108 |
-
type=int,
|
| 109 |
-
default=23,
|
| 110 |
-
required=False,
|
| 111 |
-
)
|
| 112 |
-
parser.add_argument(
|
| 113 |
-
"--num_persistent_param_in_dit",
|
| 114 |
-
type=int,
|
| 115 |
-
default=None,
|
| 116 |
-
required=False,
|
| 117 |
-
help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
|
| 118 |
-
)
|
| 119 |
-
parser.add_argument(
|
| 120 |
-
"--seed",
|
| 121 |
-
type=int,
|
| 122 |
-
default=1111,
|
| 123 |
-
required=False,
|
| 124 |
-
)
|
| 125 |
-
args = parser.parse_args()
|
| 126 |
-
return args
|
| 127 |
|
| 128 |
|
| 129 |
def load_models(args):
|
|
@@ -148,9 +61,7 @@ def load_models(args):
|
|
| 148 |
)
|
| 149 |
print("β
Wan I2V models loaded.")
|
| 150 |
|
| 151 |
-
pipe = WanVideoPipeline.from_model_manager(
|
| 152 |
-
model_manager, torch_dtype=torch.bfloat16, device="cuda"
|
| 153 |
-
)
|
| 154 |
|
| 155 |
print("π Loading FantasyTalking model...")
|
| 156 |
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
|
|
@@ -175,7 +86,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
|
|
| 175 |
print(f"π Getting duration of audio: {args.audio_path}")
|
| 176 |
duration = librosa.get_duration(filename=args.audio_path)
|
| 177 |
print(f"ποΈ Duration: {duration:.2f}s")
|
| 178 |
-
|
| 179 |
latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
|
| 180 |
num_frames = (latents_num_frames - 1) * 4
|
| 181 |
print(f"π½οΈ Calculated number of frames: {num_frames}")
|
|
@@ -217,7 +128,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
|
|
| 217 |
audio_cfg_scale=args.audio_cfg_scale,
|
| 218 |
audio_proj=audio_proj_split,
|
| 219 |
audio_context_lens=audio_context_lens,
|
| 220 |
-
latents_num_frames=
|
| 221 |
)
|
| 222 |
print("β
Video frames generated.")
|
| 223 |
|
|
@@ -247,4 +158,4 @@ if __name__ == "__main__":
|
|
| 247 |
args = parse_args()
|
| 248 |
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
|
| 249 |
video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
|
| 250 |
-
print(f"π Done! Final video path: {video_path}")
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def parse_args():
|
| 21 |
+
parser = argparse.ArgumentParser(description="FantasyTalking Video Generator")
|
| 22 |
+
|
| 23 |
+
parser.add_argument("--wan_model_dir", type=str, default="./models/Wan2.1-I2V-14B-720P")
|
| 24 |
+
parser.add_argument("--fantasytalking_model_path", type=str, default="./models/fantasytalking_model.ckpt")
|
| 25 |
+
parser.add_argument("--wav2vec_model_dir", type=str, default="./models/wav2vec2-base-960h")
|
| 26 |
+
parser.add_argument("--image_path", type=str, default="./assets/images/woman.png")
|
| 27 |
+
parser.add_argument("--audio_path", type=str, default="./assets/audios/woman.wav")
|
| 28 |
+
parser.add_argument("--prompt", type=str, default="A woman is talking.")
|
| 29 |
+
parser.add_argument("--output_dir", type=str, default="./output")
|
| 30 |
+
parser.add_argument("--image_size", type=int, default=512)
|
| 31 |
+
parser.add_argument("--audio_scale", type=float, default=1.0)
|
| 32 |
+
parser.add_argument("--prompt_cfg_scale", type=float, default=5.0)
|
| 33 |
+
parser.add_argument("--audio_cfg_scale", type=float, default=5.0)
|
| 34 |
+
parser.add_argument("--max_num_frames", type=int, default=81)
|
| 35 |
+
parser.add_argument("--fps", type=int, default=23)
|
| 36 |
+
parser.add_argument("--num_persistent_param_in_dit", type=int, default=None)
|
| 37 |
+
parser.add_argument("--seed", type=int, default=1111)
|
| 38 |
+
|
| 39 |
+
return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def load_models(args):
|
|
|
|
| 61 |
)
|
| 62 |
print("β
Wan I2V models loaded.")
|
| 63 |
|
| 64 |
+
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
|
|
|
|
|
|
| 65 |
|
| 66 |
print("π Loading FantasyTalking model...")
|
| 67 |
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
|
|
|
|
| 86 |
print(f"π Getting duration of audio: {args.audio_path}")
|
| 87 |
duration = librosa.get_duration(filename=args.audio_path)
|
| 88 |
print(f"ποΈ Duration: {duration:.2f}s")
|
| 89 |
+
|
| 90 |
latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
|
| 91 |
num_frames = (latents_num_frames - 1) * 4
|
| 92 |
print(f"π½οΈ Calculated number of frames: {num_frames}")
|
|
|
|
| 128 |
audio_cfg_scale=args.audio_cfg_scale,
|
| 129 |
audio_proj=audio_proj_split,
|
| 130 |
audio_context_lens=audio_context_lens,
|
| 131 |
+
latents_num_frames=latents_num_frames,
|
| 132 |
)
|
| 133 |
print("β
Video frames generated.")
|
| 134 |
|
|
|
|
| 158 |
args = parse_args()
|
| 159 |
pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
|
| 160 |
video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
|
| 161 |
+
print(f"π Done! Final video path: {video_path}")
|