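# NOTE(review): the six lines below look like repository file-viewer residue
# (author, commit message, commit hash, page links, file size), not source
# code. They are kept verbatim but commented out so the module can be parsed.
# gavinyuan
# add: PIPNet, arcface
# b9be4e6
# raw
# history blame
# 8.49 kB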
import cv2, os
import sys
sys.path.insert(0, '..')
import numpy as np
from PIL import Image
import logging
import copy
import importlib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from networks import *
import data_utils
from functions import *
from mobilenetv3 import mobilenetv3_large
# The script takes exactly one command-line argument: the path of the
# experiment config file. Anything else is a usage error.
if len(sys.argv) != 2:
    print('Format:')
    print('python lib/train.py config_file')
    # Report the usage error with a non-zero status so that shells and job
    # schedulers do not mistake it for a successful run. ``sys.exit`` is used
    # because the bare ``exit`` helper is only injected by the ``site`` module.
    sys.exit(1)
# The config path is expected to end in ``<data_name>/<experiment_name>.py``.
# os.path is used instead of splitting on '/' and chopping three characters,
# so a bare filename, doubled separators or OS-specific separators no longer
# produce an IndexError or a silently truncated name.
config_file = sys.argv[1]
experiment_name = os.path.splitext(os.path.basename(config_file))[0]
data_name = os.path.basename(os.path.dirname(config_file))
if not data_name:
    print('Config file must be inside a dataset directory:', config_file)
    sys.exit(1)

# Import ``.experiments.<data_name>.<experiment_name>`` relative to the
# ``PIPNet`` package and instantiate the ``Config`` class it defines.
config_path = '.experiments.{}.{}'.format(data_name, experiment_name)
my_config = importlib.import_module(config_path, package='PIPNet')
Config = getattr(my_config, 'Config')
cfg = Config()

# Record where the config came from; the names are reused below for the
# snapshot/log directories and for dataset-specific setup.
cfg.experiment_name = experiment_name
cfg.data_name = data_name

# Limit the CUDA devices visible to this process to the configured GPU id(s).
os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg.gpu_id)
# Output locations for this run:
#   ./snapshots/<data_name>/<experiment_name>  (passed to train_model as save_dir)
#   ./logs/<data_name>/<experiment_name>       (holds train.log)
# os.makedirs(..., exist_ok=True) creates missing parents such as ./snapshots
# and ./logs (plain os.mkdir raised FileNotFoundError on a fresh checkout) and
# avoids the race between the existence check and the directory creation.
save_dir = os.path.join('./snapshots', cfg.data_name, cfg.experiment_name)
os.makedirs(save_dir, exist_ok=True)

log_dir = os.path.join('./logs', cfg.data_name, cfg.experiment_name)
os.makedirs(log_dir, exist_ok=True)

# Send INFO-and-above records of the root logger to <log_dir>/train.log.
logging.basicConfig(filename=os.path.join(log_dir, 'train.log'), level=logging.INFO)
# Echo the run configuration, first to stdout and then to the log file, each
# time framed by a banner line. The fields are listed once, in report order.
banner = '###########################################'
reported_fields = (
    'experiment_name',
    'data_name',
    'det_head',
    'net_stride',
    'batch_size',
    'init_lr',
    'num_epochs',
    'decay_steps',
    'input_size',
    'backbone',
    'pretrained',
    'criterion_cls',
    'criterion_reg',
    'cls_loss_weight',
    'reg_loss_weight',
    'num_lms',
    'save_interval',
    'num_nb',
    'use_gpu',
    'gpu_id',
)

# Console report: "<field>: <value>" per line.
print(banner)
for field in reported_fields:
    print(field + ':', getattr(cfg, field))
print(banner)

# Same report, written through the logging module.
logging.info(banner)
for field in reported_fields:
    logging.info('{}: {}'.format(field, getattr(cfg, field)))
logging.info(banner)
# Only the 'pip' detection head is handled by this script; reject anything
# else before doing any work, and signal the failure with a non-zero status.
if cfg.det_head != 'pip':
    print('No such head:', cfg.det_head)
    sys.exit(1)

# get_meanface returns four values; only the first is kept here. It is later
# handed to the training dataset.
meanface_indices, _, _, _ = get_meanface(os.path.join('data', cfg.data_name, 'meanface.txt'), cfg.num_nb)

# All Pip_* wrappers are called with the same keyword arguments, so build
# them once instead of repeating them in every branch.
pip_kwargs = dict(num_lms=cfg.num_lms, input_size=cfg.input_size, net_stride=cfg.net_stride)

# Build the backbone selected by the config and wrap it in the matching
# Pip_* network.
if cfg.backbone == 'resnet18':
    resnet18 = models.resnet18(pretrained=cfg.pretrained)
    net = Pip_resnet18(resnet18, cfg.num_nb, **pip_kwargs)
elif cfg.backbone == 'resnet50':
    resnet50 = models.resnet50(pretrained=cfg.pretrained)
    net = Pip_resnet50(resnet50, cfg.num_nb, **pip_kwargs)
elif cfg.backbone == 'resnet101':
    resnet101 = models.resnet101(pretrained=cfg.pretrained)
    net = Pip_resnet101(resnet101, cfg.num_nb, **pip_kwargs)
elif cfg.backbone == 'mobilenet_v2':
    mbnet = models.mobilenet_v2(pretrained=cfg.pretrained)
    net = Pip_mbnetv2(mbnet, cfg.num_nb, **pip_kwargs)
elif cfg.backbone == 'mobilenet_v3':
    mbnet = mobilenetv3_large()
    if cfg.pretrained:
        # Load the local checkpoint onto the CPU regardless of the device it
        # was saved from, so this also works on machines without CUDA; the
        # whole network is moved to the target device afterwards.
        mbnet.load_state_dict(torch.load('lib/mobilenetv3-large-1cd25616.pth', map_location='cpu'))
    net = Pip_mbnetv3(mbnet, cfg.num_nb, **pip_kwargs)
else:
    print('No such backbone!')
    sys.exit(1)
# Run on the first visible CUDA device only when the config asks for a GPU
# and CUDA is actually available; otherwise fall back to the CPU.
use_cuda = cfg.use_gpu and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
net = net.to(device)
# Loss functions selectable from the config: 'l1' -> nn.L1Loss,
# 'l2' -> nn.MSELoss. One instance is created for each of the two criteria
# (criterion_cls and criterion_reg) that are passed to train_model.
loss_factories = {'l1': nn.L1Loss, 'l2': nn.MSELoss}

# An unknown name used to be reported but left the criterion as None, which
# was then passed on to training. Abort here instead, like the other
# configuration errors in this script.
if cfg.criterion_cls not in loss_factories:
    print('No such cls criterion:', cfg.criterion_cls)
    sys.exit(1)
criterion_cls = loss_factories[cfg.criterion_cls]()

if cfg.criterion_reg not in loss_factories:
    print('No such reg criterion:', cfg.criterion_reg)
    sys.exit(1)
criterion_reg = loss_factories[cfg.criterion_reg]()
# Per-dataset landmark permutation handed to the training dataset.
# NOTE(review): presumably entry i is the index of landmark i's counterpart
# after a horizontal flip - confirm against data_utils.ImageFolder_pip.
# The 300W, COFW and AFLW tables are written 1-based and shifted to 0-based
# below; the WFLW table is already 0-based. The asserts guard the hard-coded
# tables against accidental edits.
points_flip = None
if cfg.data_name == 'data_300W':
    points_flip = [17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 28, 29, 30, 31, 36, 35, 34, 33, 32, 46, 45, 44, 43, 48, 47, 40, 39, 38, 37, 42, 41, 55, 54, 53, 52, 51, 50, 49, 60, 59, 58, 57, 56, 65, 64, 63, 62, 61, 68, 67, 66]
    points_flip = (np.array(points_flip)-1).tolist()
    assert len(points_flip) == 68
elif cfg.data_name == 'WFLW':
    points_flip = [32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 46, 45, 44, 43, 42, 50, 49, 48, 47, 37, 36, 35, 34, 33, 41, 40, 39, 38, 51, 52, 53, 54, 59, 58, 57, 56, 55, 72, 71, 70, 69, 68, 75, 74, 73, 64, 63, 62, 61, 60, 67, 66, 65, 82, 81, 80, 79, 78, 77, 76, 87, 86, 85, 84, 83, 92, 91, 90, 89, 88, 95, 94, 93, 97, 96]
    assert len(points_flip) == 98
elif cfg.data_name == 'COFW':
    points_flip = [2, 1, 4, 3, 7, 8, 5, 6, 10, 9, 12, 11, 15, 16, 13, 14, 18, 17, 20, 19, 21, 22, 24, 23, 25, 26, 27, 28, 29]
    points_flip = (np.array(points_flip)-1).tolist()
    assert len(points_flip) == 29
elif cfg.data_name == 'AFLW':
    points_flip = [6, 5, 4, 3, 2, 1, 12, 11, 10, 9, 8, 7, 15, 14, 13, 18, 17, 16, 19]
    points_flip = (np.array(points_flip)-1).tolist()
    assert len(points_flip) == 19
else:
    print('No such data!')
    # An unsupported dataset is an error: exit with a non-zero status.
    sys.exit(1)
# Per-channel normalisation applied to image tensors in the training pipeline.
# NOTE(review): these are the statistics commonly used with torchvision's
# pretrained models - confirm they match the backbone weights in use.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
# Adam over all network parameters. Weight decay (5e-4) is only applied when
# the backbone is trained from scratch; with pretrained weights the Adam
# default is used.
adam_options = {'lr': cfg.init_lr}
if not cfg.pretrained:
    adam_options['weight_decay'] = 5e-4
optimizer = optim.Adam(net.parameters(), **adam_options)

# Multiply the learning rate by 0.1 at each milestone in cfg.decay_steps.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=cfg.decay_steps,
    gamma=0.1,
)
# Training annotations for the selected dataset, read via get_label.
labels = get_label(cfg.data_name, 'train.txt')

# Only the 'pip' head has a dataset implementation here; treat anything else
# as an error and exit with a non-zero status.
if cfg.det_head != 'pip':
    print('No such head:', cfg.det_head)
    sys.exit(1)

# Image-level transforms: random grayscale conversion with probability 0.2,
# conversion to a tensor, then per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomGrayscale(0.2),
    transforms.ToTensor(),
    normalize,
])

# Dataset over data/<data_name>/images_train, built from the labels and the
# landmark tables prepared above.
train_data = data_utils.ImageFolder_pip(
    os.path.join('data', cfg.data_name, 'images_train'),
    labels, cfg.input_size, cfg.num_lms,
    cfg.net_stride, points_flip, meanface_indices,
    train_transform,
)
# Batch the training set: reshuffled every epoch, loaded by 8 worker
# processes into pinned memory, with the final incomplete batch dropped.
loader_options = dict(
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)
train_loader = torch.utils.data.DataLoader(train_data, **loader_options)

# Start training. The arguments are positional and must stay in this order.
train_model(
    cfg.det_head,
    net,
    train_loader,
    criterion_cls,
    criterion_reg,
    cfg.cls_loss_weight,
    cfg.reg_loss_weight,
    cfg.num_nb,
    optimizer,
    cfg.num_epochs,
    scheduler,
    save_dir,
    cfg.save_interval,
    device,
)