import socket
import timeit
from datetime import datetime
import os
import sys
import glob
import numpy as np
from collections import OrderedDict
sys.path.append('../../')
sys.path.append('../../networks/')
# PyTorch includes
import torch
from torch.autograd import Variable
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
# Tensorboard include
from tensorboardX import SummaryWriter
# Custom includes
from dataloaders import cihp
from utils import util, get_iou_from_list
from networks import deeplab_xception_transfer, graph
from dataloaders import custom_transforms as tr
#
import argparse
gpu_id = 0
nEpochs = 100  # Number of epochs for training
resume_epoch = 0  # Default is 0; change to resume training from a saved epoch
def flip(x, dim):
indices = [slice(None)] * x.dim()
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
dtype=torch.long, device=x.device)
return x[tuple(indices)]
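# A quick sanity check (sketch): reversing the last dim of a 1 x 1 x 4 tensor,
# e.g. flip(torch.arange(4).view(1, 1, 4), dim=-1), yields tensor([[[3, 2, 1, 0]]]).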
def flip_cihp(tail_list):
    '''
    Swap the left/right channel pairs of a CIHP prediction so it matches a
    horizontally flipped image (14/15 arms, 16/17 legs, 18/19 shoes).
    :param tail_list: tensor of size n_class x h x w (logits for one sample)
    :return: tensor of the same size with the left/right channels swapped
    '''
tail_list_rev = [None] * 20
for xx in range(14):
tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
tail_list_rev[14] = tail_list[15].unsqueeze(0)
tail_list_rev[15] = tail_list[14].unsqueeze(0)
tail_list_rev[16] = tail_list[17].unsqueeze(0)
tail_list_rev[17] = tail_list[16].unsqueeze(0)
tail_list_rev[18] = tail_list[19].unsqueeze(0)
tail_list_rev[19] = tail_list[18].unsqueeze(0)
    return torch.cat(tail_list_rev, dim=0)
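# Note: channels 0-13 are copied through unchanged; only the paired parts are
# swapped, which assumes the standard CIHP label order for left/right body parts.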
def get_parser():
'''argparse begin'''
parser = argparse.ArgumentParser()
    # argparse Action that maps a string choice ('true'/'false') to its value in the choices dict
    LookupChoices = type('', (argparse.Action,), dict(__call__=lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--batch', default=16, type=int)
    parser.add_argument('--lr', default=1e-7, type=float)
    parser.add_argument('--numworker', default=12, type=int)
    parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
    parser.add_argument('--step', default=10, type=int)
    parser.add_argument('--classes', default=20, type=int)
    parser.add_argument('--testInterval', default=10, type=int)
    parser.add_argument('--loadmodel', default='', type=str)
    parser.add_argument('--pretrainedModel', default='', type=str)
    parser.add_argument('--hidden_layers', default=128, type=int)
    parser.add_argument('--gpus', default=4, type=int)
opts = parser.parse_args()
return opts
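# Example invocation (a sketch; 'train_cihp.py' stands in for this script's name):
#   python train_cihp.py --batch 16 --gpus 4 --lr 1e-7 --loadmodel ./pretrained.pth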
def get_graphs(opts):
adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
adj2 = adj2_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20).transpose(2, 3).cuda()
adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3)
adj1_ = Variable(torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
adj3 = adj1_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 7).cuda()
adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)
# adj2 = torch.from_numpy(graph.cihp2pascal_adj).float()
# adj2 = adj2.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 7, 20)
cihp_adj = graph.preprocess_adj(graph.cihp_graph)
adj3_ = Variable(torch.from_numpy(cihp_adj).float())
adj1 = adj3_.unsqueeze(0).unsqueeze(0).expand(opts.gpus, 1, 20, 20).cuda()
adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
train_graph = [adj1, adj2, adj3]
test_graph = [adj1_test, adj2_test, adj3_test]
return train_graph, test_graph
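# For reference: adj1 is the 20 x 20 CIHP label graph, adj2 the 20 x 7
# CIHP-to-PASCAL transfer matrix, and adj3 the 7 x 7 PASCAL graph. The train
# copies are expanded to opts.gpus replicas so DataParallel can scatter one
# adjacency per GPU; the test copies keep a single replica on the CPU.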
def val_cihp(net_, testloader, testloader_flip, test_graph, epoch, writer, criterion, classes=20):
adj1_test, adj2_test, adj3_test = test_graph
num_img_ts = len(testloader)
net_.eval()
pred_list = []
label_list = []
running_loss_ts = 0.0
miou = 0
for ii, sample_batched in enumerate(zip(testloader, testloader_flip)):
inputs, labels = sample_batched[0]['image'], sample_batched[0]['label']
inputs_f, _ = sample_batched[1]['image'], sample_batched[1]['label']
inputs = torch.cat((inputs, inputs_f), dim=0)
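        # The original and flipped images are stacked into one batch of two, so a
        # single forward pass produces predictions for both views.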
# Forward pass of the mini-batch
inputs, labels = Variable(inputs, requires_grad=False), Variable(labels)
if gpu_id >= 0:
inputs, labels = inputs.cuda(), labels.cuda()
with torch.no_grad():
outputs = net_.forward(inputs, adj1_test.cuda(), adj3_test.cuda(), adj2_test.cuda())
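            # Fuse the two views: outputs[0] is the prediction for the original image;
            # outputs[1] is for the flipped copy, so it is un-flipped both spatially
            # (flip) and channel-wise (flip_cihp) before averaging.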
outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
outputs = outputs.unsqueeze(0)
predictions = torch.max(outputs, 1)[1]
pred_list.append(predictions.cpu())
label_list.append(labels.squeeze(1).cpu())
loss = criterion(outputs, labels, batch_average=True)
running_loss_ts += loss.item()
# total_iou += utils.get_iou(predictions, labels)
# Print stuff
if ii % num_img_ts == num_img_ts - 1:
miou = get_iou_from_list(pred_list, label_list, n_cls=classes)
running_loss_ts = running_loss_ts / num_img_ts
print('Validation:')
            print('[Epoch: %d, numImages: %5d]' % (epoch, ii + inputs.data.shape[0]))
writer.add_scalar('data/test_loss_epoch', running_loss_ts, epoch)
            writer.add_scalar('data/test_miou', miou, epoch)
print('Loss: %f' % running_loss_ts)
print('MIoU: %f\n' % miou)
def main(opts):
p = OrderedDict() # Parameters to include in report
p['trainBatch'] = opts.batch # Training batch size
testBatch = 1 # Testing batch size
    useTest = True  # Evaluate on the test set while training
nTestInterval = opts.testInterval # Run on test set every nTestInterval epochs
snapshot = 1 # Store a model every snapshot epochs
p['nAveGrad'] = 1 # Average the gradient of several iterations
p['lr'] = opts.lr # Learning rate
p['lrFtr'] = 1e-5
p['lraspp'] = 1e-5
p['lrpro'] = 1e-5
p['lrdecoder'] = 1e-5
p['lrother'] = 1e-5
p['wd'] = 5e-4 # Weight decay
p['momentum'] = 0.9 # Momentum
p['epoch_size'] = opts.step # How many epochs to change learning rate
p['num_workers'] = opts.numworker
model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as the feature extractor
nEpochs = opts.epochs
max_id = 0
    save_dir_root = os.path.dirname(os.path.abspath(__file__))
    exp_name = os.path.basename(save_dir_root)
runs = glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*'))
for r in runs:
run_id = int(r.split('_')[-1])
if run_id >= max_id:
max_id = run_id + 1
save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))
# Network definition
if backbone == 'xception':
net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
hidden_layers=opts.hidden_layers, source_classes=7, )
elif backbone == 'resnet':
# net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
raise NotImplementedError
else:
raise NotImplementedError
    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime('%b%d_%H-%M-%S')
criterion = util.cross_entropy2d
if gpu_id >= 0:
# torch.cuda.set_device(device=gpu_id)
net_.cuda()
# net load weights
    if model_path != '':
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print('loaded pretrained model:', model_path)
    else:
        print('no pretrained model.')
    if opts.loadmodel != '':
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print('loaded model:', opts.loadmodel)
    else:
        print('no model loaded!')
log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
writer = SummaryWriter(log_dir=log_dir)
writer.add_text('load model',opts.loadmodel,1)
writer.add_text('setting',sys.argv[0],1)
if opts.freezeBN:
net_.freeze_bn()
# Use the following optimizer
optimizer = optim.SGD(net_.parameters(), lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
composed_transforms_tr = transforms.Compose([
tr.RandomSized_new(512),
tr.Normalize_xception_tf(),
tr.ToTensor_()])
composed_transforms_ts = transforms.Compose([
tr.Normalize_xception_tf(),
tr.ToTensor_()])
composed_transforms_ts_flip = transforms.Compose([
tr.HorizontalFlip(),
tr.Normalize_xception_tf(),
tr.ToTensor_()])
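    # Three pipelines: training adds random resizing (RandomSized_new) for
    # augmentation; validation is deterministic; the third mirrors the validation
    # images so predictions can be flip-fused in val_cihp.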
voc_train = cihp.VOCSegmentation(split='train', transform=composed_transforms_tr, flip=True)
voc_val = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts)
voc_val_flip = cihp.VOCSegmentation(split='val', transform=composed_transforms_ts_flip)
trainloader = DataLoader(voc_train, batch_size=p['trainBatch'], shuffle=True, num_workers=p['num_workers'],drop_last=True)
testloader = DataLoader(voc_val, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
testloader_flip = DataLoader(voc_val_flip, batch_size=testBatch, shuffle=False, num_workers=p['num_workers'])
num_img_tr = len(trainloader)
num_img_ts = len(testloader)
running_loss_tr = 0.0
running_loss_ts = 0.0
aveGrad = 0
global_step = 0
print("Training Network")
net = torch.nn.DataParallel(net_)
train_graph, test_graph = get_graphs(opts)
adj1, adj2, adj3 = train_graph
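    # net wraps net_ for multi-GPU training, while net_ keeps the underlying
    # module so checkpoints and evaluation use the unwrapped weights; the train
    # graphs were expanded to one replica per GPU to match DataParallel's scatter.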
# Main Training and Testing Loop
for epoch in range(resume_epoch, nEpochs):
start_time = timeit.default_timer()
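        # Every p['epoch_size'] epochs the optimizer is rebuilt with a poly-decayed
        # rate; util.lr_poly presumably computes base_lr * (1 - epoch / nEpochs) ** 0.9.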
if epoch % p['epoch_size'] == p['epoch_size'] - 1:
lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
optimizer = optim.SGD(net_.parameters(), lr=lr_, momentum=p['momentum'], weight_decay=p['wd'])
writer.add_scalar('data/lr_', lr_, epoch)
print('(poly lr policy) learning rate: ', lr_)
net.train()
for ii, sample_batched in enumerate(trainloader):
inputs, labels = sample_batched['image'], sample_batched['label']
# Forward-Backward of the mini-batch
inputs, labels = Variable(inputs, requires_grad=True), Variable(labels)
global_step += inputs.data.shape[0]
if gpu_id >= 0:
inputs, labels = inputs.cuda(), labels.cuda()
outputs = net.forward(inputs, adj1, adj3, adj2)
loss = criterion(outputs, labels, batch_average=True)
running_loss_tr += loss.item()
# Print stuff
if ii % num_img_tr == (num_img_tr - 1):
running_loss_tr = running_loss_tr / num_img_tr
writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)
print('[Epoch: %d, numImages: %5d]' % (epoch, ii * p['trainBatch'] + inputs.data.shape[0]))
print('Loss: %f' % running_loss_tr)
running_loss_tr = 0
stop_time = timeit.default_timer()
print("Execution time: " + str(stop_time - start_time) + "\n")
# Backward the averaged gradient
loss /= p['nAveGrad']
loss.backward()
aveGrad += 1
# Update the weights once in p['nAveGrad'] forward passes
if aveGrad % p['nAveGrad'] == 0:
writer.add_scalar('data/total_loss_iter', loss.item(), ii + num_img_tr * epoch)
optimizer.step()
optimizer.zero_grad()
aveGrad = 0
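            # With nAveGrad = 1 this is a plain per-batch SGD step; a value of k
            # would accumulate gradients over k mini-batches, approximating a
            # k-times-larger effective batch before each optimizer update.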
# Show 10 * 3 images results each epoch
            if ii % max(num_img_tr // 10, 1) == 0:
grid_image = make_grid(inputs[:3].clone().cpu().data, 3, normalize=True)
writer.add_image('Image', grid_image, global_step)
grid_image = make_grid(util.decode_seg_map_sequence(torch.max(outputs[:3], 1)[1].detach().cpu().numpy()), 3, normalize=False,
range=(0, 255))
writer.add_image('Predicted label', grid_image, global_step)
grid_image = make_grid(util.decode_seg_map_sequence(torch.squeeze(labels[:3], 1).detach().cpu().numpy()), 3, normalize=False, range=(0, 255))
writer.add_image('Groundtruth label', grid_image, global_step)
print('loss is ', loss.cpu().item(), flush=True)
# Save the model
if (epoch % snapshot) == snapshot - 1:
torch.save(net_.state_dict(), os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth'))
print("Save model at {}\n".format(os.path.join(save_dir, 'models', modelName + '_epoch-' + str(epoch) + '.pth')))
torch.cuda.empty_cache()
# One testing epoch
if useTest and epoch % nTestInterval == (nTestInterval - 1):
val_cihp(net_,testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
epoch=epoch,writer=writer,criterion=criterion, classes=opts.classes)
torch.cuda.empty_cache()
if __name__ == '__main__':
opts = get_parser()
main(opts)