import gc import torch as t def freeze_model(model): model.eval() for params in model.parameters(): params.requires_grad = False def unfreeze_model(model): model.train() for params in model.parameters(): params.requires_grad = True def zero_grad(model): for p in model.parameters(): if p.requires_grad and p.grad is not None: p.grad = None def empty_cache(): gc.collect() t.cuda.empty_cache() def assert_shape(x, exp_shape): assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}" def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def count_state(model): return sum(s.numel() for s in model.state_dict().values()) import argparse def parse_args(): parser = argparse.ArgumentParser(description='Codebook') parser.add_argument('--config', default='./configs/codebook.yml') parser.add_argument('--gpu', type=str, default='2') parser.add_argument('--no_cuda', type=list, default=['2']) parser.add_argument('--prefix', type=str, required=False, default='knn_pred_wavvq') parser.add_argument('--save_path', type=str, required=False, default="./Speech2GestureMatching/output/") parser.add_argument('--code_path', type=str, required=False) parser.add_argument('--VQVAE_model_path', type=str, required=False) parser.add_argument('--BEAT_path', type=str, default="../dataset/orig_BEAT/speakers/") parser.add_argument('--save_dir', type=str, default="../dataset/BEAT") parser.add_argument('--step', type=str, default="1") parser.add_argument('--stage', type=str, default="train") args = parser.parse_args() return args