File size: 3,602 Bytes
7f2690b
 
 
 
 
 
 
 
 
 
 
 
81fb07b
 
 
 
7f2690b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from collections import OrderedDict
import os
import numpy as np
import random
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Add the parent directory to the module search path so that a top-level
# `data` module located there can be imported.
sys.path.append('..')
try:
    import data
except ImportError:
    # Fall back to the packaged copy. The `as data` alias is required: a plain
    # `import a.b.c` only binds the top-level name `a`, which would leave
    # `data` undefined for the code below that calls `getattr(data, ...)`.
    import foleycrafter.models.specvqgan.onset_baseline.data as data


# ---------------------------------------------------- #
def load_model(cp_path, net, device=None, strict=True):
    """Load the weights stored in a checkpoint file into ``net``.

    The checkpoint must be a mapping holding a ``'state_dict'`` entry (the
    weights) and an ``'epoch'`` entry. Keys of the state dict that start with
    ``'module.'`` (typically present when the model was saved while wrapped in
    a parallel wrapper such as ``nn.DataParallel``) have that prefix removed
    before the weights are loaded.

    Args:
        cp_path: Path of the checkpoint file.
        net: Module that receives the weights through ``load_state_dict``.
        device: ``map_location`` handed to ``torch.load``; the CPU device is
            used when this is falsy.
        strict: Forwarded unchanged to ``net.load_state_dict``.

    Returns:
        A ``(net, start_epoch)`` tuple, where ``start_epoch`` is the value of
        the checkpoint's ``'epoch'`` entry.

    Raises:
        SystemExit: With a non-zero status when ``cp_path`` is not an existing
            file, so that calling scripts do not report success.
    """
    if not device:
        device = torch.device('cpu')

    if not os.path.isfile(cp_path):
        print("=> no checkpoint found at '{}'".format(cp_path))
        # Exit with a failure status: a missing checkpoint is an error.
        sys.exit(1)

    print("=> loading checkpoint '{}'".format(cp_path))
    # NOTE: torch.load relies on pickle; only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(cp_path, map_location=device)

    prefix = 'module.'
    raw_state = checkpoint['state_dict']
    if any(key.startswith(prefix) for key in raw_state):
        # Strip the prefix per key so a mixed or reordered dict is handled
        # too; keys without the prefix are kept unchanged.
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in raw_state.items()
        )
    else:
        # No prefixed keys (this includes an empty state dict): use as is.
        state_dict = raw_state
    net.load_state_dict(state_dict, strict=strict)

    start_epoch = checkpoint['epoch']
    print("=> loaded checkpoint '{}' (epoch {})"
                .format(cp_path, start_epoch))
    return net, start_epoch


# ---------------------------------------------------- #
def binary_acc(pred, target, thred):
    pred = pred > thred
    acc = np.sum(pred == target) / target.shape[0]
    return acc

def calc_acc(prob, labels, k):
    """Return the top-``k`` accuracy as a float tensor.

    The indices of ``prob`` are sorted along the last dimension in descending
    order and the first ``k`` are kept. Matches against ``labels`` (viewed as
    a column) are counted and divided by ``labels.size(0)``.
    """
    ranking = torch.argsort(prob, dim=-1, descending=True)
    top_k_idx = ranking[..., :k]
    matches = top_k_idx == labels.view(-1, 1)
    return torch.sum(matches).float() / labels.size(0)

# ---------------------------------------------------- #

def get_dataloader(args, pr, split='train', shuffle=False, drop_last=False, batch_size=None):
    """Build the dataset for ``split`` and wrap it in a ``DataLoader``.

    Args:
        args: Run options; ``args.batch_size`` and ``args.num_workers`` are
            read here, and ``args`` is also passed to the dataset constructor.
        pr: Parameter object; ``pr.dataloader`` names the dataset class to
            look up in the ``data`` module, and ``pr.list_train`` /
            ``pr.list_val`` / ``pr.list_test`` supply the list for each split.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        batch_size: Batch size to use; falls back to ``args.batch_size`` when
            falsy.

    Returns:
        A ``(dataset, loader)`` tuple.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    list_attr_by_split = {
        'train': 'list_train',
        'val': 'list_val',
        'test': 'list_test',
    }
    # Reject unknown splits up front instead of failing later on an
    # unassigned local variable.
    if split not in list_attr_by_split:
        raise ValueError(
            "unknown split '{}'; expected one of {}".format(
                split, sorted(list_attr_by_split)))

    data_loader = getattr(data, pr.dataloader)
    read_list = getattr(pr, list_attr_by_split[split])
    dataset = data_loader(args, pr, read_list, split=split)
    batch_size = batch_size if batch_size else args.batch_size
    # Invoke the dataset's own `getitem_test` hook before building the loader
    # (presumably a sanity check on item loading; confirm in the dataset class).
    dataset.getitem_test(1)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=drop_last)

    return dataset, loader


# ---------------------------------------------------- #
def make_optimizer(model, args):
    '''
    Build the optimizer selected by ``args.optim`` for the trainable
    parameters of ``model``.

    Args:
        model: NN to train; only parameters with ``requires_grad`` set are
            handed to the optimizer.
        args: options object; reads ``optim`` ('SGD' or 'Adam'), ``lr`` and
            ``weight_decay``, plus ``momentum`` for SGD.
    Returns:
        optimizer: pytorch optimizer for updating the given model parameters.
    Raises:
        ValueError: if ``args.optim`` is neither 'SGD' nor 'Adam'.
    '''
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    if args.optim == 'SGD':
        return torch.optim.SGD(
            trainable_params,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
            nesterov=False
        )
    if args.optim == 'Adam':
        return torch.optim.Adam(
            trainable_params,
            lr=args.lr,
            weight_decay=args.weight_decay,
        )
    # Fail with a clear message rather than on an unassigned local variable.
    raise ValueError(
        "unsupported optimizer '{}'; expected 'SGD' or 'Adam'".format(args.optim))


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for the given epoch.

    With ``args.schedule == 'cos'`` the base rate ``args.lr`` is scaled by
    ``0.5 * (1 + cos(pi * epoch / args.epochs))``; for any other schedule
    value the base rate is used unchanged.
    """
    base_lr = args.lr
    if args.schedule == 'cos':
        # Cosine schedule: scale the base rate by the cosine factor.
        cosine_factor = 0.5 * (1. + np.cos(np.pi * epoch / args.epochs))
        new_lr = base_lr * cosine_factor
    else:
        # 'none' (and anything else): keep the base rate.
        new_lr = base_lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr