Optimize memory usage
app.py
CHANGED
@@ -208,32 +208,19 @@ seed_everything(seed, workers=True)
 with open("ThinkSound/configs/model_configs/thinksound.json") as f:
     model_config = json.load(f)
 
-…
+diffusion_model = create_model_from_config(model_config)
+ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound_light.ckpt",repo_type="model")
+diffusion_model.load_state_dict(torch.load(ckpt_path))
+diffusion_model.to(device)
 
 ## speed by torch.compile
 if args.compile:
-…
-…
-if args.pretrained_ckpt_path:
-    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
-
-if args.remove_pretransform_weight_norm == "pre_load":
-    remove_weight_norm_from_model(model.pretransform)
+    diffusion_model = torch.compile(diffusion_model)
 
 
 load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
 # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
-…
-…
-# Remove weight_norm from the pretransform if specified
-if args.remove_pretransform_weight_norm == "post_load":
-    remove_weight_norm_from_model(model.pretransform)
-ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound", filename="thinksound.ckpt",repo_type="model")
-training_wrapper = create_training_wrapper_from_config(model_config, model)
-# choose map_location according to the device when loading the model weights
-training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
-
-training_wrapper.to("cuda")
+diffusion_model.pretransform.load_state_dict(load_vae_state)
 
 def get_video_duration(video_path):
     video = VideoFileClip(video_path)
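Note on the hunk above: the memory saving comes from dropping the training-time scaffolding (the `training_wrapper` built by `create_training_wrapper_from_config`, the `copy_state_dict` of a pretrained checkpoint, and the weight-norm removal branches) and loading the lighter `thinksound_light.ckpt` straight into the bare diffusion model. The removed comment recommended choosing `map_location` according to the device, which the new `torch.load(ckpt_path)` call does not do. A minimal sketch of that device-aware loading, reusing names from the diff and assuming the app's `create_model_from_config` helper and `model_config` from above:

import torch
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

diffusion_model = create_model_from_config(model_config)  # app helper, assumed from the diff
ckpt_path = hf_hub_download(repo_id="FunAudioLLM/ThinkSound",
                            filename="thinksound_light.ckpt", repo_type="model")
# Staging the weights on the CPU keeps peak GPU memory at a single copy of the
# model, and the same code path still works on CPU-only Spaces hardware.
state_dict = torch.load(ckpt_path, map_location="cpu")
diffusion_model.load_state_dict(state_dict)
diffusion_model.to(device)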
@@ -276,36 +263,36 @@ def synthesize_video_with_audio(video_file, caption, cot):
     sync_seq_len = preprocessed_data['sync_features'].shape[0]
     clip_seq_len = preprocessed_data['metaclip_features'].shape[0]
     latent_seq_len = (int)(194/9*duration_sec)
-    …
+    diffusion_model.model.model.update_seq_lengths(latent_seq_len, clip_seq_len, sync_seq_len)
 
     metadata = [preprocessed_data]
 
     batch_size = 1
     length = latent_seq_len
     with torch.amp.autocast(device):
-        conditioning = …
+        conditioning = diffusion_model.conditioner(metadata, device)
 
     video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
-    conditioning['metaclip_features'][~video_exist] = …
-    conditioning['sync_features'][~video_exist] = …
+    conditioning['metaclip_features'][~video_exist] = diffusion_model.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = diffusion_model.model.model.empty_sync_feat
 
     yield "⏳ Inferring…", None
 
-    cond_inputs = …
-    noise = torch.randn([batch_size, …
+    cond_inputs = diffusion_model.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, diffusion_model.io_channels, length]).to(device)
     with torch.amp.autocast(device):
-        …
-        …
-        …
-        elif training_wrapper.diffusion_objective == "rectified_flow":
+        if diffusion_model.diffusion_objective == "v":
+            fakes = sample(diffusion_model.model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+        elif diffusion_model.diffusion_objective == "rectified_flow":
             import time
             start_time = time.time()
-            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+            fakes = sample_discrete_euler(diffusion_model.model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
             end_time = time.time()
             execution_time = end_time - start_time
-            print(f"…
-            …
-            …
+            print(f"execution_time: {execution_time:.2f} s")
+
+        if diffusion_model.pretransform is not None:
+            fakes = diffusion_model.pretransform.decode(fakes)
 
     audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
     with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
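Note on the hunk above: sequence lengths are recomputed per request from the clip duration (`latent_seq_len = (int)(194/9*duration_sec)`, about 21.6 latent frames per second), batch items without video are masked with the model's learned empty CLIP/sync features, and generation runs 24 discrete Euler steps under classifier-free guidance (`cfg_scale=5`). For orientation, an illustrative Euler integrator for a rectified-flow objective, analogous to but not the library's `sample_discrete_euler`; the model is assumed to predict a velocity field, and CFG is omitted for brevity:

import torch

@torch.no_grad()
def euler_sample(model, noise, steps, **cond):
    # Integrate dx/dt = v(x, t) from t=1 (pure noise) down to t=0 (data).
    ts = torch.linspace(1.0, 0.0, steps + 1, device=noise.device)
    x = noise
    for t_curr, t_next in zip(ts[:-1], ts[1:]):
        t = t_curr.expand(x.shape[0])   # one timestep value per batch item
        v = model(x, t, **cond)         # predicted velocity at (x, t)
        x = x + (t_next - t_curr) * v   # single Euler step toward the data end
    return x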
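Note on the final context lines: the decoded waveform is peak-normalized to ±1 and quantized to 16-bit PCM before being written to a temporary WAV file. The same conversion unpacked into a helper (the name is illustrative; it assumes a non-silent output so the divide by the peak is safe):

import torch

def to_int16_pcm(fakes: torch.Tensor) -> torch.Tensor:
    audio = fakes.to(torch.float32)
    peak = torch.max(torch.abs(audio))    # global peak over the whole batch
    audio = (audio / peak).clamp(-1, 1)   # normalize; clamp guards float error
    return audio.mul(32767).to(torch.int16).cpu()  # map [-1, 1] into the int16 range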
|