# NOTE: page-scrape residue ("Spaces: Running Running"); not part of the program.
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Jiefeng Li ([email protected])
# -----------------------------------------------------
import torch | |
import torch.utils.data | |
from .utils.dataset import coco | |
from opt import opt | |
from tqdm import tqdm | |
from models.FastPose import createModel | |
from .utils.eval import DataLogger, accuracy | |
from .utils.img import flip, shuffleLR | |
from .evaluation import prediction | |
from tensorboardX import SummaryWriter | |
import os | |
def train(train_loader, m, criterion, optimizer, writer):
    """Run one optimisation pass over ``train_loader``.

    For each batch the model output is multiplied by the batch mask before
    the loss and the accuracy are computed; the optimizer then takes one
    step. ``opt.trainIters`` is incremented once per batch and used as the
    step index for the scalars written to ``writer``.

    Args:
        train_loader: iterable yielding ``(inps, labels, setMask, imgset)``
            batches; its ``dataset`` attribute is forwarded to ``accuracy``.
        m: model, switched to training mode here.
        criterion: loss callable applied to ``(masked_output, labels)``.
        optimizer: optimizer stepped once per batch.
        writer: object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(loss_avg, acc_avg)`` read from the ``avg`` attribute of the
        two ``DataLogger`` instances after the last batch.
    """
    loss_meter = DataLogger()
    acc_meter = DataLogger()
    m.train()

    progress = tqdm(train_loader)
    # The fourth batch element (imgset) is not used during training.
    for inps, labels, set_mask, _ in progress:
        # requires_grad_() works in place, so no re-binding is needed.
        inps.requires_grad_()

        out = m(inps)
        loss = criterion(out.mul(set_mask), labels)

        # Accuracy is computed on the raw tensors (.data), outside autograd.
        acc = accuracy(out.data.mul(set_mask), labels.data,
                       train_loader.dataset)

        batch_size = inps.size(0)
        acc_meter.update(acc[0], batch_size)
        loss_meter.update(loss.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        opt.trainIters += 1

        # Tensorboard
        writer.add_scalar('Train/Loss', loss_meter.avg, opt.trainIters)
        writer.add_scalar('Train/Acc', acc_meter.avg, opt.trainIters)

        # TQDM
        description = 'loss: {loss:.8f} | acc: {acc:.2f}'.format(
            loss=loss_meter.avg,
            acc=acc_meter.avg * 100)
        progress.set_description(description)

    progress.close()

    return loss_meter.avg, acc_meter.avg
def valid(val_loader, m, criterion, optimizer, writer):
    """Evaluate ``m`` on ``val_loader`` with flip-averaged predictions.

    The loss is computed on the masked output of the un-flipped input. The
    accuracy is computed on the average of that output and the output for
    the flipped input (after ``shuffleLR`` and flipping back).
    ``opt.valIters`` is incremented once per batch and used as the step
    index for the scalars written to ``writer``.

    Args:
        val_loader: iterable yielding ``(inps, labels, setMask, imgset)``
            batches; its ``dataset`` attribute is forwarded to ``shuffleLR``
            and ``accuracy``.
        m: model, switched to evaluation mode here.
        criterion: loss callable applied to ``(masked_output, labels)``.
        optimizer: unused; kept so the signature mirrors ``train``.
        writer: object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(loss_avg, acc_avg)`` read from the ``avg`` attribute of the
        two ``DataLogger`` instances after the last batch.
    """
    loss_meter = DataLogger()
    acc_meter = DataLogger()
    m.eval()

    progress = tqdm(val_loader)
    # The fourth batch element (imgset) is not used during validation.
    for inps, labels, set_mask, _ in progress:
        with torch.no_grad():
            out = m(inps)

            # Loss is taken on the plain forward pass only.
            loss = criterion(out.mul(set_mask), labels)

            # Second forward pass on the flipped input, mapped back.
            mirrored = m(flip(inps))
            mirrored = flip(shuffleLR(mirrored, val_loader.dataset))

            out = (mirrored + out) / 2

        acc = accuracy(out.mul(set_mask), labels, val_loader.dataset)

        batch_size = inps.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc_meter.update(acc[0], batch_size)

        opt.valIters += 1

        # Tensorboard
        writer.add_scalar('Valid/Loss', loss_meter.avg, opt.valIters)
        writer.add_scalar('Valid/Acc', acc_meter.avg, opt.valIters)

        description = 'loss: {loss:.8f} | acc: {acc:.2f}'.format(
            loss=loss_meter.avg,
            acc=acc_meter.avg * 100)
        progress.set_description(description)

    progress.close()

    return loss_meter.avg, acc_meter.avg
def main():
    """Build the model, train it for ``opt.nEpochs`` epochs and snapshot it.

    All configuration is read from the module-level ``opt`` object. Each
    epoch runs ``train`` and then ``valid``. Every ``opt.snapshot`` epochs
    the model weights, ``opt`` itself and the optimizer are saved under
    ``../exp/<dataset>/<expID>/``. TensorBoard scalars are written under
    ``.tensorboard/<dataset>/<expID>``.

    Raises:
        ValueError: if ``opt.optMethod`` is neither ``'rmsprop'`` nor
            ``'adam'``, or if ``opt.dataset`` is not ``'coco'``.
    """
    # Model Initialize
    m = createModel()
    if opt.loadModel:
        print('Loading Model from {}'.format(opt.loadModel))
        # NOTE(review): torch.load unpickles the file -- only point
        # opt.loadModel at checkpoints from a trusted source.
        m.load_state_dict(
            torch.load(opt.loadModel, map_location=torch.device('cpu')))
    else:
        print('Create new model')

    # One call replaces the duplicated check-then-mkdir code: it creates
    # every missing parent and tolerates an already-existing directory.
    exp_dir = "../exp/{}/{}".format(opt.dataset, opt.expID)
    os.makedirs(exp_dir, exist_ok=True)

    criterion = torch.nn.MSELoss()

    if opt.optMethod == 'rmsprop':
        optimizer = torch.optim.RMSprop(m.parameters(),
                                        lr=opt.LR,
                                        momentum=opt.momentum,
                                        weight_decay=opt.weightDecay)
    elif opt.optMethod == 'adam':
        optimizer = torch.optim.Adam(
            m.parameters(),
            lr=opt.LR
        )
    else:
        raise ValueError(
            "Unsupported optMethod {!r}; expected 'rmsprop' or 'adam'".format(
                opt.optMethod))

    writer = SummaryWriter(
        '.tensorboard/{}/{}'.format(opt.dataset, opt.expID))

    # try/finally so the event file is closed even if training fails.
    try:
        # Prepare Dataset
        if opt.dataset == 'coco':
            train_dataset = coco.Mscoco(train=True)
            val_dataset = coco.Mscoco(train=False)
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'coco'".format(
                    opt.dataset))

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=opt.trainBatch, shuffle=True,
            num_workers=opt.nThreads, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=opt.validBatch, shuffle=False,
            num_workers=opt.nThreads, pin_memory=True)

        # Model Transfer
        m = torch.nn.DataParallel(m)

        # Start Training
        for i in range(opt.nEpochs):
            opt.epoch = i

            print('############# Starting Epoch {} #############'.format(opt.epoch))
            loss, acc = train(train_loader, m, criterion, optimizer, writer)

            print('Train-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
                idx=opt.epoch,
                loss=loss,
                acc=acc
            ))

            opt.acc = acc
            opt.loss = loss
            if i % opt.snapshot == 0:
                # Save the wrapped module's weights so the checkpoint does
                # not carry DataParallel's "module." key prefix.
                torch.save(
                    m.module.state_dict(),
                    '{}/model_{}.pkl'.format(exp_dir, opt.epoch))
                torch.save(opt, '{}/option.pkl'.format(exp_dir))
                torch.save(optimizer, '{}/optimizer.pkl'.format(exp_dir))

            loss, acc = valid(val_loader, m, criterion, optimizer, writer)

            print('Valid-{idx:d} epoch | loss:{loss:.8f} | acc:{acc:.4f}'.format(
                idx=i,
                loss=loss,
                acc=acc
            ))

            # NOTE: the per-epoch mAP evaluation through `prediction(m)` for
            # non-'mpii' datasets was already disabled in this script and
            # stays disabled.
    finally:
        writer.close()
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()