File size: 4,579 Bytes
07e1105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from scipy import stats
import numpy as np
from models import monet as MoNet
from models import gc_loss as GC_Loss
from utils.dataset import data_loader
import json
import random
import os
from tqdm import tqdm


def get_data(dataset, data_path='./utils/dataset/dataset_info.json', train_ratio=0.8, seed=None):
    """Look up a dataset's image root and randomly split its indices into train/test.

    ``data_path`` must point to a JSON object that maps each dataset name to a
    two-element ``[path, img_num]`` pair.

    Args:
        dataset: key of the dataset inside the JSON file.
        data_path: path to the dataset-info JSON file.
        train_ratio: fraction of the indices assigned to the training split;
            the split size is ``int(round(train_ratio * img_num))``. Defaults to 0.8.
        seed: optional seed for a private ``random.Random`` so the split is
            reproducible. When ``None`` (the default) the global ``random``
            module state is used, as before.

    Returns:
        tuple: ``(path, train_index, test_index)``. The two index lists are
        disjoint and together cover ``range(img_num)``.

    Raises:
        KeyError: if ``dataset`` is not present in the JSON file.
        ValueError: if ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be in [0, 1], got {}'.format(train_ratio))

    # Keep the file handle and the parsed content under separate names.
    with open(data_path, 'r', encoding='utf-8') as info_file:
        data_info = json.load(info_file)
    path, img_num = data_info[dataset]

    indices = list(range(img_num))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(indices)

    # Compute the split point once; same rounding as the original 80/20 split.
    split = int(round(train_ratio * len(indices)))
    return path, indices[:split], indices[split:]


def cal_srocc_plcc(pred_score, gt_score):
    """Correlate predicted scores with ground-truth scores.

    Args:
        pred_score: sequence of predicted scores.
        gt_score: sequence of ground-truth scores, aligned with ``pred_score``.

    Returns:
        tuple: ``(srocc, plcc)``, i.e. Spearman's rank-order and Pearson's
        linear correlation coefficients. The p-values are discarded.
    """
    rank_corr, _rank_pvalue = stats.spearmanr(pred_score, gt_score)
    linear_corr, _linear_pvalue = stats.pearsonr(pred_score, gt_score)
    return rank_corr, linear_corr


class Solver:
    """Train and evaluate a MoNet score-regression model on one dataset.

    Image indices are split into train/test sets by ``get_data``. After every
    epoch the model is evaluated on the test split, and the weights with the
    best ``SROCC + PLCC`` so far are saved to ``<config.save_path>/best_model.pkl``.
    """

    def __init__(self, config):
        """Build data loaders, loss, model, optimizer and LR scheduler.

        Args:
            config: run configuration. Attributes read here: ``dataset``,
                ``loss``, ``queue_ratio`` (GC loss only), ``epochs``, ``lr``,
                ``weight_decay``, ``T_max``, ``eta_min`` and ``save_path``.
                The whole object is also forwarded to ``Data_Loader`` and ``MoNet``.

        Raises:
            ValueError: if ``config.loss`` is not 'MAE', 'MSE' or 'GC'.
        """
        path, train_index, test_index = get_data(dataset=config.dataset)

        train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
        test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()

        print('Training data number: ', len(train_index))
        print('Testing data number: ', len(test_index))

        self.loss = self._build_loss(config, len(train_index))

        print('Loading MoNet...')
        self.MoNet = MoNet.MoNet(config).cuda()
        self.MoNet.train(True)

        self.epochs = config.epochs
        self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=config.T_max, eta_min=config.eta_min)

        self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')

    @staticmethod
    def _build_loss(config, num_train):
        """Return the loss module selected by ``config.loss``.

        Args:
            config: object with a ``loss`` attribute ('MAE', 'MSE' or 'GC') and,
                for 'GC', a ``queue_ratio`` attribute.
            num_train: number of training samples; 'GC' sizes ``queue_len`` as
                ``int(num_train * config.queue_ratio)``.

        Raises:
            ValueError: for any other ``config.loss`` value.
        """
        if config.loss == 'MAE':
            return torch.nn.L1Loss().cuda()
        if config.loss == 'MSE':
            return torch.nn.MSELoss().cuda()
        if config.loss == 'GC':
            return GC_Loss.GC_Loss(queue_len=int(num_train * config.queue_ratio))
        # Raising a plain string is itself a TypeError in Python 3 and would hide
        # this message; raise a real exception instead.
        raise ValueError('Only Support MAE, MSE and GC loss.')

    def _train_one_epoch(self):
        """Run one optimization pass over the training data.

        Returns:
            tuple: ``(epoch_losses, train_srocc, train_plcc)``, where
            ``epoch_losses`` is the list of per-batch loss values.
        """
        epoch_loss = []
        pred_scores = []
        gt_scores = []

        for img, label in tqdm(self.train_data):
            img = img.cuda()
            label = label.view(-1).cuda()

            self.optimizer.zero_grad()

            # Flatten to match ``label``'s shape; unlike squeeze(), this keeps a
            # batch of size 1 as shape (1,) instead of a 0-d tensor.
            pred = self.MoNet(img).reshape(-1)

            # extend() avoids the O(n^2) copying of repeated list concatenation.
            pred_scores.extend(pred.detach().cpu().tolist())
            gt_scores.extend(label.cpu().tolist())

            loss = self.loss(pred, label.float())
            epoch_loss.append(loss.item())

            loss.backward()
            self.optimizer.step()
            # Stepped once per batch, so the scheduler's T_max counts iterations,
            # not epochs.
            self.scheduler.step()

        train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)
        return epoch_loss, train_srocc, train_plcc

    def train(self):
        """Run the full training loop and keep the best checkpoint.

        Whenever the test ``SROCC + PLCC`` beats the best seen so far, the model
        ``state_dict`` is saved to ``self.model_save_path``.

        Returns:
            tuple: ``(best_srocc, best_plcc)`` measured on the test split.
        """
        best_srocc = 0.0
        best_plcc = 0.0
        print('----------------------------------')
        print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
        for t in range(self.epochs):
            epoch_loss, train_srocc, train_plcc = self._train_one_epoch()

            test_srocc, test_plcc = self.test()
            if test_srocc + test_plcc > best_srocc + best_plcc:
                best_srocc = test_srocc
                best_plcc = test_plcc
                torch.save(self.MoNet.state_dict(), self.model_save_path)
                print('Model saved in: ', self.model_save_path)

            print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(np.mean(epoch_loss), 4), round(train_srocc, 4),
                                                  round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4)))

        print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))

        return best_srocc, best_plcc

    def test(self):
        """Evaluate the model on the test split without gradients.

        The model is put in eval mode for the pass and always returned to
        training mode afterwards, even if evaluation raises.

        Returns:
            tuple: ``(test_srocc, test_plcc)``.
        """
        self.MoNet.train(False)
        pred_scores = []
        gt_scores = []

        try:
            with torch.no_grad():
                for img, label in tqdm(self.test_data):
                    img = img.cuda()
                    label = label.view(-1).cuda()

                    pred = self.MoNet(img)

                    # Flatten so the score lists stay 1-D and aligned with the labels.
                    pred_scores.extend(pred.reshape(-1).cpu().tolist())
                    gt_scores.extend(label.cpu().tolist())
        finally:
            self.MoNet.train(True)

        return cal_srocc_plcc(pred_scores, gt_scores)