Update FantasyTalking/infer.py
FantasyTalking/infer.py   CHANGED   (+38 -25)
@@ -127,7 +127,7 @@ def parse_args():
 
 
 def load_models(args):
-
+    print("🔄 Loading Wan I2V models...")
     model_manager = ModelManager(device="cpu")
     model_manager.load_models(
         [
@@ -144,50 +144,63 @@ def load_models(args):
             f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
             f"{args.wan_model_dir}/Wan2.1_VAE.pth",
         ],
-
-        torch_dtype=torch.bfloat16,  # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
+        torch_dtype=torch.bfloat16,
     )
+    print("✅ Wan I2V models loaded.")
+
     pipe = WanVideoPipeline.from_model_manager(
         model_manager, torch_dtype=torch.bfloat16, device="cuda"
     )
 
-
+    print("🔄 Loading FantasyTalking model...")
     fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
     fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
+    print("✅ FantasyTalking model loaded.")
 
-
-    pipe.enable_vram_management(
-        num_persistent_param_in_dit=args.num_persistent_param_in_dit
-    )
+    print("🧠 Enabling VRAM management...")
+    pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)
 
-
+    print("🔄 Loading Wav2Vec2 processor and model...")
     wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
     wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")
+    print("✅ Wav2Vec2 loaded.")
 
     return pipe, fantasytalking, wav2vec_processor, wav2vec
 
 
 def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
+    print("📁 Creating output directory...")
     os.makedirs(args.output_dir, exist_ok=True)
 
+    print(f"🔊 Getting duration of audio: {args.audio_path}")
     duration = librosa.get_duration(filename=args.audio_path)
+    print(f"🎞️ Duration: {duration:.2f}s")
+
     num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
+    print(f"📽️ Calculated number of frames: {num_frames}")
 
+    print("🎧 Extracting audio features...")
     audio_wav2vec_fea = get_audio_features(
         wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames
     )
+    print("✅ Audio features extracted.")
+
+    print("🖼️ Loading and resizing image...")
     image = resize_image_by_longest_edge(args.image_path, args.image_size)
     width, height = image.size
+    print(f"✅ Image resized to: {width}x{height}")
 
+    print("🔄 Projecting audio features...")
     audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
     pos_idx_ranges = fantasytalking.split_audio_sequence(
         audio_proj_fea.size(1), num_frames=num_frames
     )
     audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(
         audio_proj_fea, pos_idx_ranges, expand_length=4
-    )
+    )
+    print("✅ Audio features projected and split.")
 
-
+    print("🚀 Generating video from image + audio...")
     video_audio = pipe(
         prompt=args.prompt,
         negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
@@ -205,32 +218,32 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
         audio_context_lens=audio_context_lens,
         latents_num_frames=(num_frames - 1) // 4 + 1,
     )
+    print("✅ Video frames generated.")
+
     current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
     save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    print(f"💾 Saving temporary video without audio to: {save_path_tmp}")
     save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)
 
     save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
+    print(f"🔊 Merging video with audio using FFmpeg...")
+
     final_command = [
-        "ffmpeg",
-        "-y",
-        "-i",
-        save_path_tmp,
-        "-i",
-        args.audio_path,
-        "-c:v",
-        "libx264",
-        "-c:a",
-        "aac",
-        "-shortest",
-        save_path,
+        "ffmpeg", "-y", "-i", save_path_tmp, "-i", args.audio_path,
+        "-c:v", "libx264", "-c:a", "aac", "-shortest", save_path,
     ]
     subprocess.run(final_command, check=True)
+    print(f"✅ Final video saved to: {save_path}")
+
+    print("🧹 Removing temporary video file...")
     os.remove(save_path_tmp)
+
     return save_path
 
 
 if __name__ == "__main__":
+    print("🚦 Starting main script...")
     args = parse_args()
     pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
-    main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
-
+    video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
+    print(f"🎉 Done! Final video path: {video_path}")
|