wedyanessam committed on
Commit
39618e2
·
verified ·
1 Parent(s): ed4ad79

Update FantasyTalking/infer.py

Browse files
Files changed (1) hide show
  1. FantasyTalking/infer.py +23 -112
FantasyTalking/infer.py CHANGED
@@ -18,112 +18,25 @@ from FantasyTalking.utils import get_audio_features, resize_image_by_longest_edg
18
 
19
 
20
  def parse_args():
21
- parser = argparse.ArgumentParser(description="Simple example of a training script.")
22
-
23
- parser.add_argument(
24
- "--wan_model_dir",
25
- type=str,
26
- default="./models/Wan2.1-I2V-14B-720P",
27
- required=False,
28
- help="The dir of the Wan I2V 14B model.",
29
- )
30
- parser.add_argument(
31
- "--fantasytalking_model_path",
32
- type=str,
33
- default="./models/fantasytalking_model.ckpt",
34
- required=False,
35
- help="The .ckpt path of fantasytalking model.",
36
- )
37
- parser.add_argument(
38
- "--wav2vec_model_dir",
39
- type=str,
40
- default="./models/wav2vec2-base-960h",
41
- required=False,
42
- help="The dir of wav2vec model.",
43
- )
44
-
45
- parser.add_argument(
46
- "--image_path",
47
- type=str,
48
- default="./assets/images/woman.png",
49
- required=False,
50
- help="The path of the image.",
51
- )
52
-
53
- parser.add_argument(
54
- "--audio_path",
55
- type=str,
56
- default="./assets/audios/woman.wav",
57
- required=False,
58
- help="The path of the audio.",
59
- )
60
- parser.add_argument(
61
- "--prompt",
62
- type=str,
63
- default="A woman is talking.",
64
- required=False,
65
- help="prompt.",
66
- )
67
- parser.add_argument(
68
- "--output_dir",
69
- type=str,
70
- default="./output",
71
- help="Dir to save the model.",
72
- )
73
- parser.add_argument(
74
- "--image_size",
75
- type=int,
76
- default=512,
77
- help="The image will be resized proportionally to this size.",
78
- )
79
- parser.add_argument(
80
- "--audio_scale",
81
- type=float,
82
- default=1.0,
83
- help="Audio condition injection weight",
84
- )
85
- parser.add_argument(
86
- "--prompt_cfg_scale",
87
- type=float,
88
- default=5.0,
89
- required=False,
90
- help="Prompt cfg scale",
91
- )
92
- parser.add_argument(
93
- "--audio_cfg_scale",
94
- type=float,
95
- default=5.0,
96
- required=False,
97
- help="Audio cfg scale",
98
- )
99
- parser.add_argument(
100
- "--max_num_frames",
101
- type=int,
102
- default=81,
103
- required=False,
104
- help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated.",
105
- )
106
- parser.add_argument(
107
- "--fps",
108
- type=int,
109
- default=23,
110
- required=False,
111
- )
112
- parser.add_argument(
113
- "--num_persistent_param_in_dit",
114
- type=int,
115
- default=None,
116
- required=False,
117
- help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
118
- )
119
- parser.add_argument(
120
- "--seed",
121
- type=int,
122
- default=1111,
123
- required=False,
124
- )
125
- args = parser.parse_args()
126
- return args
127
 
128
 
129
  def load_models(args):
@@ -148,9 +61,7 @@ def load_models(args):
148
  )
149
  print("✅ Wan I2V models loaded.")
150
 
151
- pipe = WanVideoPipeline.from_model_manager(
152
- model_manager, torch_dtype=torch.bfloat16, device="cuda"
153
- )
154
 
155
  print("🔄 Loading FantasyTalking model...")
156
  fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
@@ -175,7 +86,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
175
  print(f"🔊 Getting duration of audio: {args.audio_path}")
176
  duration = librosa.get_duration(filename=args.audio_path)
177
  print(f"🎞️ Duration: {duration:.2f}s")
178
-
179
  latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
180
  num_frames = (latents_num_frames - 1) * 4
181
  print(f"📽️ Calculated number of frames: {num_frames}")
@@ -217,7 +128,7 @@ def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
217
  audio_cfg_scale=args.audio_cfg_scale,
218
  audio_proj=audio_proj_split,
219
  audio_context_lens=audio_context_lens,
220
- latents_num_frames=(num_frames - 1) // 4 + 1,
221
  )
222
  print("✅ Video frames generated.")
223
 
@@ -247,4 +158,4 @@ if __name__ == "__main__":
247
  args = parse_args()
248
  pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
249
  video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
250
- print(f"🎉 Done! Final video path: {video_path}")
 
18
 
19
 
20
  def parse_args():
21
+ parser = argparse.ArgumentParser(description="FantasyTalking Video Generator")
22
+
23
+ parser.add_argument("--wan_model_dir", type=str, default="./models/Wan2.1-I2V-14B-720P")
24
+ parser.add_argument("--fantasytalking_model_path", type=str, default="./models/fantasytalking_model.ckpt")
25
+ parser.add_argument("--wav2vec_model_dir", type=str, default="./models/wav2vec2-base-960h")
26
+ parser.add_argument("--image_path", type=str, default="./assets/images/woman.png")
27
+ parser.add_argument("--audio_path", type=str, default="./assets/audios/woman.wav")
28
+ parser.add_argument("--prompt", type=str, default="A woman is talking.")
29
+ parser.add_argument("--output_dir", type=str, default="./output")
30
+ parser.add_argument("--image_size", type=int, default=512)
31
+ parser.add_argument("--audio_scale", type=float, default=1.0)
32
+ parser.add_argument("--prompt_cfg_scale", type=float, default=5.0)
33
+ parser.add_argument("--audio_cfg_scale", type=float, default=5.0)
34
+ parser.add_argument("--max_num_frames", type=int, default=81)
35
+ parser.add_argument("--fps", type=int, default=23)
36
+ parser.add_argument("--num_persistent_param_in_dit", type=int, default=None)
37
+ parser.add_argument("--seed", type=int, default=1111)
38
+
39
+ return parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def load_models(args):
 
61
  )
62
  print("✅ Wan I2V models loaded.")
63
 
64
+ pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
 
 
65
 
66
  print("🔄 Loading FantasyTalking model...")
67
  fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
 
86
  print(f"🔊 Getting duration of audio: {args.audio_path}")
87
  duration = librosa.get_duration(filename=args.audio_path)
88
  print(f"🎞️ Duration: {duration:.2f}s")
89
+
90
  latents_num_frames = min(int(duration * args.fps / 4), args.max_num_frames // 4)
91
  num_frames = (latents_num_frames - 1) * 4
92
  print(f"📽️ Calculated number of frames: {num_frames}")
 
128
  audio_cfg_scale=args.audio_cfg_scale,
129
  audio_proj=audio_proj_split,
130
  audio_context_lens=audio_context_lens,
131
+ latents_num_frames=latents_num_frames,
132
  )
133
  print("✅ Video frames generated.")
134
 
 
158
  args = parse_args()
159
  pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
160
  video_path = main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)
161
+ print(f"🎉 Done! Final video path: {video_path}")