Spaces:
Runtime error
Runtime error
| from collections import OrderedDict | |
| import os | |
| import numpy as np | |
| import random | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
sys.path.append('..')
try:
    import data
except ImportError:
    # Fall back to the package-qualified module, bound to the name `data`
    # (which get_dataloader reads). A dotted `import a.b.c.data` would only
    # bind the top-level name `foleycrafter`, leaving `data` undefined.
    from foleycrafter.models.specvqgan.onset_baseline import data
| # ---------------------------------------------------- # | |
def load_model(cp_path, net, device=None, strict=True):
    """Load a training checkpoint's weights into ``net``.

    The checkpoint file must be a dict with a ``'state_dict'`` entry (the
    model weights) and an ``'epoch'`` entry. Keys that carry a ``'module.'``
    prefix (presumably written while the model was wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel``) have that prefix
    stripped, so the weights load into an unwrapped module.

    Args:
        cp_path: Path to the checkpoint file.
        net: Module whose parameters are overwritten in place.
        device: ``map_location`` passed to ``torch.load``; CPU when None.
        strict: Forwarded to ``net.load_state_dict``.

    Returns:
        ``(net, start_epoch)``, where ``start_epoch`` is the checkpoint's
        ``'epoch'`` value.

    Raises:
        SystemExit: with status 1 if ``cp_path`` is not an existing file.
        KeyError: if the checkpoint lacks ``'state_dict'`` or ``'epoch'``.
    """
    if device is None:
        device = torch.device('cpu')
    if not os.path.isfile(cp_path):
        print("=> no checkpoint found at '{}'".format(cp_path))
        # Use a non-zero status: a bare sys.exit() would report success.
        sys.exit(1)

    print("=> loading checkpoint '{}'".format(cp_path))
    checkpoint = torch.load(cp_path, map_location=device)

    state_dict = checkpoint['state_dict']
    prefix = 'module.'
    # Strip the wrapper prefix key by key, and only for keys that have it.
    # Rebuild only when needed, so an unprefixed state_dict passes through
    # untouched.
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
    net.load_state_dict(state_dict, strict=strict)

    start_epoch = checkpoint['epoch']
    print("=> loaded checkpoint '{}' (epoch {})".format(cp_path, start_epoch))
    return net, start_epoch
| # ---------------------------------------------------- # | |
def binary_acc(pred, target, thred):
    """Return the fraction of thresholded predictions that equal ``target``.

    ``pred > thred`` is compared elementwise with ``target``. The number of
    matches is divided by ``target.shape[0]``, the length of the first axis.
    """
    is_positive = pred > thred
    n_match = np.sum(is_positive == target)
    return n_match / target.shape[0]
def calc_acc(prob, labels, k):
    """Return top-``k`` accuracy as a 0-dim float tensor.

    Indices are ranked by descending score along the last dimension of
    ``prob``. The first ``k`` indices are compared against ``labels``
    reshaped to a column. The match count is divided by ``labels.size(0)``.
    Assumes ``prob`` is (N, num_classes) and ``labels`` holds N class
    indices; TODO confirm against callers.
    """
    ranked = torch.argsort(prob, dim=-1, descending=True)
    top_k = ranked[..., :k]
    targets = labels.view(-1, 1)
    n_correct = (top_k == targets).sum().float()
    return n_correct / labels.size(0)
| # ---------------------------------------------------- # | |
def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

    Args:
        args: Run config. Reads ``num_workers``, and ``batch_size`` when the
            ``batch_size`` argument is falsy.
        pr: Params object. ``pr.dataloader`` names a dataset class in the
            ``data`` module. ``pr.list_train`` / ``pr.list_val`` /
            ``pr.list_test`` is the read list passed for the matching split.
        split: One of ``'train'``, ``'val'``, ``'test'``.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        batch_size: Overrides ``args.batch_size`` when truthy.

    Returns:
        ``(dataset, loader)``.

    Raises:
        ValueError: if ``split`` is not one of the supported names.
    """
    split_to_list_attr = {
        'train': 'list_train',
        'val': 'list_val',
        'test': 'list_test',
    }
    # Validate first, so an unknown split fails with a clear message instead
    # of an UnboundLocalError further down.
    if split not in split_to_list_attr:
        raise ValueError(
            "split must be one of {}, got {!r}".format(sorted(split_to_list_attr), split))
    read_list = getattr(pr, split_to_list_attr[split])

    data_loader = getattr(data, pr.dataloader)
    dataset = data_loader(args, pr, read_list, split=split)
    batch_size = batch_size or args.batch_size
    # NOTE(review): eagerly invokes the dataset's getitem_test hook on index 1.
    # Presumably a load sanity check defined in `data`; confirm it is still
    # wanted.
    dataset.getitem_test(1)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_last,
    )
    return dataset, loader
| # ---------------------------------------------------- # | |
def make_optimizer(model, args):
    '''
    Build an optimizer over the model's trainable parameters.

    Args:
        model: NN to train. Only parameters with ``requires_grad`` set are
            handed to the optimizer.
        args: Config namespace. Reads ``optim`` ('SGD' or 'Adam'), ``lr`` and
            ``weight_decay``; SGD additionally reads ``momentum``.
    Returns:
        optimizer: pytorch optimizer for updating the given model parameters.
    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer name.
    '''
    # Frozen parameters (requires_grad=False) are excluded from updates.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.optim == 'SGD':
        return torch.optim.SGD(
            params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False,
        )
    if args.optim == 'Adam':
        return torch.optim.Adam(
            params,
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'SGD' or 'Adam'".format(args.optim))
def adjust_learning_rate(optimizer, epoch, args):
    """Write the learning rate for ``epoch`` into every optimizer param group.

    When ``args.schedule == 'cos'``, the base rate ``args.lr`` is scaled by
    ``0.5 * (1 + cos(pi * epoch / args.epochs))``. Any other schedule value,
    including ``'none'``, leaves the base rate unchanged.
    """
    base_lr = args.lr
    if args.schedule == 'cos':
        cosine_factor = 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
        new_lr = base_lr * cosine_factor
    else:
        new_lr = base_lr
    for group in optimizer.param_groups:
        group['lr'] = new_lr