File size: 4,579 Bytes
07e1105 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
from scipy import stats
import numpy as np
from models import monet as MoNet
from models import gc_loss as GC_Loss
from utils.dataset import data_loader
import json
import random
import os
from tqdm import tqdm
def get_data(dataset, data_path='./utils/dataset/dataset_info.json', train_ratio=0.8):
    """Randomly split a dataset's image indices into train and test sets.

    The JSON file at ``data_path`` maps each dataset name to a two-element
    ``[path, img_num]`` entry. Indices ``0 .. img_num - 1`` are shuffled with
    the global ``random`` module state. The first
    ``int(round(train_ratio * img_num))`` indices go to training and the rest
    to testing.

    Args:
        dataset: Key to look up in the JSON file.
        data_path: Path to the dataset-info JSON file.
        train_ratio: Fraction of images assigned to the training split.
            Defaults to 0.8, the previously hard-coded value.

    Returns:
        Tuple ``(path, train_index, test_index)``. The two index lists are
        disjoint and together cover every image index.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
        KeyError: If ``dataset`` is not a key of the JSON file.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be in [0, 1], got {}'.format(train_ratio))
    with open(data_path, 'r', encoding='utf-8') as info_file:
        data_info = json.load(info_file)
    path, img_num = data_info[dataset]
    indices = list(range(img_num))
    random.shuffle(indices)
    # Compute the split point once so both slices are guaranteed to agree.
    split = int(round(train_ratio * len(indices)))
    return path, indices[:split], indices[split:]
def cal_srocc_plcc(pred_score, gt_score):
    """Return (SROCC, PLCC) between predicted and ground-truth scores.

    SROCC is Spearman's rank-order correlation (``scipy.stats.spearmanr``).
    PLCC is Pearson's linear correlation (``scipy.stats.pearsonr``).

    Args:
        pred_score: Predicted scores, any array-like. Nested input is
            flattened to 1-D first. One example is a list of one-element
            lists, which is what ``tolist()`` produces for an (N, 1) tensor.
        gt_score: Ground-truth scores, flattened the same way.

    Returns:
        Tuple ``(srocc, plcc)`` as Python floats.

    Raises:
        ValueError: If the flattened inputs differ in length.
    """
    # Flatten so that column-shaped predictions line up element-wise with a
    # flat label list. Without this, pearsonr rejects or mis-broadcasts them.
    pred = np.asarray(pred_score, dtype=float).ravel()
    gt = np.asarray(gt_score, dtype=float).ravel()
    if pred.shape != gt.shape:
        raise ValueError('pred_score and gt_score must have the same number of '
                         'elements, got {} and {}'.format(pred.size, gt.size))
    srocc, _ = stats.spearmanr(pred, gt)
    plcc, _ = stats.pearsonr(pred, gt)
    return float(srocc), float(plcc)
class Solver:
    """Train and evaluate a MoNet model on one dataset.

    The constructor does the following:

    * splits the dataset into train/test indices with ``get_data``;
    * builds the data loaders;
    * picks the loss named by ``config.loss``;
    * creates an Adam optimizer with a cosine-annealing learning-rate schedule.

    ``train()`` runs ``config.epochs`` epochs. After each epoch it evaluates on
    the test split and saves the weights whenever test SROCC + PLCC improves.
    """

    def __init__(self, config):
        """Build data loaders, loss, model, optimizer and LR scheduler.

        Args:
            config: Experiment configuration. This method reads ``dataset``,
                ``loss``, ``queue_ratio`` (GC loss only), ``epochs``, ``lr``,
                ``weight_decay``, ``T_max``, ``eta_min`` and ``save_path``.
                The whole object is also passed to ``Data_Loader`` and
                ``MoNet``.

        Raises:
            ValueError: If ``config.loss`` is not 'MAE', 'MSE' or 'GC'.
        """
        path, train_index, test_index = get_data(dataset=config.dataset)
        train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
        test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()
        print('Traning data number: ', len(train_index))
        print('Testing data number: ', len(test_index))

        self.loss = self._build_loss(config, len(train_index))

        print('Loading MoNet...')
        self.MoNet = MoNet.MoNet(config).cuda()
        self.MoNet.train(True)

        self.epochs = config.epochs
        self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr,
                                          weight_decay=config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config.T_max, eta_min=config.eta_min)

        # torch.save does not create missing directories. Creating the
        # directory here means a bad path fails now, not after the first epoch.
        if config.save_path:
            os.makedirs(config.save_path, exist_ok=True)
        self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')

    @staticmethod
    def _build_loss(config, num_train):
        """Return the training criterion selected by ``config.loss``.

        ``num_train`` is the number of training images. For the GC loss, the
        queue length is ``int(num_train * config.queue_ratio)``.
        """
        if config.loss == 'MAE':
            return torch.nn.L1Loss().cuda()
        if config.loss == 'MSE':
            return torch.nn.MSELoss().cuda()
        if config.loss == 'GC':
            return GC_Loss.GC_Loss(queue_len=int(num_train * config.queue_ratio))
        # Raising a plain string is itself a TypeError in Python 3, which would
        # hide this message. Raise a real exception instead.
        raise ValueError('Only Support MAE, MSE and GC loss.')

    def _train_one_epoch(self):
        """Run one optimisation pass over the training data.

        Returns:
            Tuple ``(mean_loss, srocc, plcc)`` computed over the epoch's
            training predictions.
        """
        epoch_loss = []
        pred_scores = []
        gt_scores = []
        for img, label in tqdm(self.train_data):
            img = img.cuda()
            label = label.view(-1).cuda()
            self.optimizer.zero_grad()
            pred = self.MoNet(img)
            # Flatten the predictions. If MoNet returns (B, 1) scores, this
            # gives B scalars that line up with the B labels. extend() keeps
            # accumulation linear in dataset size, unlike list concatenation.
            pred_scores.extend(pred.detach().reshape(-1).cpu().tolist())
            gt_scores.extend(label.cpu().tolist())
            loss = self.loss(pred.squeeze(), label.float().detach())
            epoch_loss.append(loss.item())
            loss.backward()
            self.optimizer.step()
        train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)
        return np.mean(epoch_loss), train_srocc, train_plcc

    def train(self):
        """Training.

        Runs ``self.epochs`` epochs and steps the LR scheduler once per epoch.
        After each epoch the model is evaluated on the test split. Whenever
        test SROCC + PLCC beats the best so far, the weights are written to
        ``self.model_save_path``.

        Returns:
            Tuple ``(best_srocc, best_plcc)``. This stays ``(0.0, 0.0)``, and
            no checkpoint is written, if no epoch's test SROCC + PLCC
            exceeds 0.
        """
        best_srocc = 0.0
        best_plcc = 0.0
        print('----------------------------------')
        print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
        for t in range(self.epochs):
            train_loss, train_srocc, train_plcc = self._train_one_epoch()
            self.scheduler.step()
            test_srocc, test_plcc = self.test()
            if test_srocc + test_plcc > best_srocc + best_plcc:
                best_srocc = test_srocc
                best_plcc = test_plcc
                torch.save(self.MoNet.state_dict(), self.model_save_path)
                print('Model saved in: ', self.model_save_path)
            print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(train_loss, 4), round(train_srocc, 4),
                                                  round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4)))
        print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))
        return best_srocc, best_plcc

    def test(self):
        """Testing.

        Evaluates the model on the test split with training mode off and
        gradients disabled. Training mode is switched back on afterwards,
        even if evaluation raises.

        Returns:
            Tuple ``(srocc, plcc)`` over the whole test split.
        """
        self.MoNet.train(False)
        pred_scores = []
        gt_scores = []
        try:
            with torch.no_grad():
                for img, label in tqdm(self.test_data):
                    img = img.cuda()
                    label = label.view(-1).cuda()
                    pred = self.MoNet(img)
                    # Flatten for the same reason as in training.
                    pred_scores.extend(pred.reshape(-1).cpu().tolist())
                    gt_scores.extend(label.cpu().tolist())
        finally:
            self.MoNet.train(True)
        test_srocc, test_plcc = cal_srocc_plcc(pred_scores, gt_scores)
        return test_srocc, test_plcc
|