wedyanessam committed on
Commit
6af8e5d
·
verified ·
1 Parent(s): ce79c62

Update FantasyTalking/infer.py

Browse files
Files changed (1) hide show
  1. FantasyTalking/infer.py +28 -30
FantasyTalking/infer.py CHANGED
@@ -127,38 +127,35 @@ def parse_args():
127
 
128
 
129
  def load_models(args):
130
- # Load Wan I2V models
131
- model_manager = ModelManager(device="cpu")
132
- model_manager.load_models(
133
- [
134
- [
135
- f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
136
- f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
137
- f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
138
- f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
139
- f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
140
- f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
141
- f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
142
- ],
143
- f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
144
- f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
145
- f"{args.wan_model_dir}/Wan2.1_VAE.pth",
146
- ],
147
- # torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
148
- torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization.
149
- )
150
- pipe = WanVideoPipeline.from_model_manager(
151
- model_manager, torch_dtype=torch.bfloat16, device="cuda"
152
- )
 
 
153
 
154
  # Load FantasyTalking weights
155
- fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
156
- fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)
157
-
158
- # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required.
159
- pipe.enable_vram_management(
160
- num_persistent_param_in_dit=args.num_persistent_param_in_dit
161
- )
162
 
163
  # Load wav2vec models
164
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
@@ -167,6 +164,7 @@ def load_models(args):
167
  return pipe, fantasytalking, wav2vec_processor, wav2vec
168
 
169
 
 
170
  def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
171
  os.makedirs(args.output_dir, exist_ok=True)
172
 
 
127
 
128
 
129
  def load_models(args):
130
+ # Load Wan I2V models (تم التعليق عليها مؤقتاً)
131
+ # model_manager = ModelManager(device="cpu")
132
+ # model_manager.load_models(
133
+ # [
134
+ # [
135
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
136
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
137
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
138
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
139
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
140
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
141
+ # f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
142
+ # ],
143
+ # f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
144
+ # f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
145
+ # f"{args.wan_model_dir}/Wan2.1_VAE.pth",
146
+ # ],
147
+ # torch_dtype=torch.bfloat16,
148
+ # )
149
+ # pipe = WanVideoPipeline.from_model_manager(
150
+ # model_manager, torch_dtype=torch.bfloat16, device="cuda"
151
+ # )
152
+
153
+ # مبدئياً نحط pipe بـ None أو تقدرِ تحطي أي Placeholder مؤقت
154
+ pipe = None
155
 
156
  # Load FantasyTalking weights
157
+ fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to("cuda")
158
+ fantasytalking.load_audio_processor(args.fantasytalking_model_path, None)
 
 
 
 
 
159
 
160
  # Load wav2vec models
161
  wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
 
164
  return pipe, fantasytalking, wav2vec_processor, wav2vec
165
 
166
 
167
+
168
  def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
169
  os.makedirs(args.output_dir, exist_ok=True)
170