from collections import OrderedDict
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

sys.path.append('..')

try:
    import data
except ImportError:
    # fall back to the packaged module path when run outside the repo root
    import foleycrafter.models.specvqgan.onset_baseline.data as data
# ---------------------------------------------------- #
def load_model(cp_path, net, device=None, strict=True):
    if not device:
        device = torch.device('cpu')
    if os.path.isfile(cp_path):
        print("=> loading checkpoint '{}'".format(cp_path))
        checkpoint = torch.load(cp_path, map_location=device)
        # strip the 'module.' prefix left by DataParallel/DDP wrappers
        if list(checkpoint['state_dict'].keys())[0][:7] == 'module.':
            state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                name = k[7:]
                state_dict[name] = v
        else:
            state_dict = checkpoint['state_dict']
        net.load_state_dict(state_dict, strict=strict)
        print("=> loaded checkpoint '{}' (epoch {})".format(
            cp_path, checkpoint['epoch']))
        start_epoch = checkpoint['epoch']
    else:
        print("=> no checkpoint found at '{}'".format(cp_path))
        start_epoch = 0
        sys.exit()
    return net, start_epoch
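

# Example (illustrative only): resume a network from a checkpoint file.
# `MyModel` and the path below are placeholders; the checkpoint is assumed to be
# a dict with 'state_dict' and 'epoch' keys, as load_model above expects.
#
#   net = MyModel()
#   net, start_epoch = load_model('checkpoints/latest.pth', net,
#                                 device=torch.device('cuda'), strict=True)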
# ---------------------------------------------------- #
def binary_acc(pred, target, thred):
    # fraction of predictions whose score exceeds the threshold and matches the label
    pred = pred > thred
    acc = np.sum(pred == target) / target.shape[0]
    return acc


def calc_acc(prob, labels, k):
    # top-k accuracy: count samples whose label appears among the k highest-probability classes
    pred = torch.argsort(prob, dim=-1, descending=True)[..., :k]
    top_k_acc = torch.sum(pred == labels.view(-1, 1)).float() / labels.size(0)
    return top_k_acc
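

# Example (illustrative, assumed shapes): for class probabilities `prob` of shape
# (batch, num_classes) and integer `labels` of shape (batch,):
#
#   calc_acc(prob, labels, k=1)   # top-1 accuracy
#   calc_acc(prob, labels, k=5)   # top-5 accuracy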
# ---------------------------------------------------- #
def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    data_loader = getattr(data, pr.dataloader)
    if split == 'train':
        read_list = pr.list_train
    elif split == 'val':
        read_list = pr.list_val
    elif split == 'test':
        read_list = pr.list_test
    dataset = data_loader(args, pr, read_list, split=split)
    batch_size = batch_size if batch_size else args.batch_size
    dataset.getitem_test(1)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_last)
    return dataset, loader
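

# Example (illustrative): `args` and `pr` are assumed to carry the fields used
# above (batch_size, num_workers, dataloader name, list_train/list_val/list_test).
#
#   train_set, train_loader = get_dataloader(args, pr, split='train',
#                                            shuffle=True, drop_last=True)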
# ---------------------------------------------------- #
def make_optimizer(model, args):
    '''
    Args:
        model: NN to train
    Returns:
        optimizer: PyTorch optimizer for updating the given model parameters.
    '''
    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            weight_decay=args.weight_decay)
    return optimizer
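

# Example (illustrative): `args` is assumed to provide optim, lr, momentum and
# weight_decay, e.g. args.optim = 'Adam', args.lr = 1e-4, args.weight_decay = 1e-5.
#
#   optimizer = make_optimizer(net, args)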
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.schedule == 'cos':  # cosine lr schedule
        lr *= 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
    elif args.schedule == 'none':  # no lr schedule
        lr = args.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
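

# Example (illustrative): with args.schedule = 'cos', args.lr = 0.1 and
# args.epochs = 100, the cosine schedule gives lr = 0.05 at epoch 50 and decays
# toward 0 as epoch approaches args.epochs. Typically called once per epoch:
#
#   adjust_learning_rate(optimizer, epoch, args)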