import time
import datetime

import torch
import tqdm

from lib.config import cfg

class Trainer(object):
    def __init__(self, network):
        # Move the network onto this process's GPU.
        device = torch.device('cuda:{}'.format(cfg.local_rank))
        network = network.to(device)
        if cfg.distributed:
            # Wrap the model for multi-process training; each process
            # owns the single device given by its local rank.
            network = torch.nn.parallel.DistributedDataParallel(
                network,
                device_ids=[cfg.local_rank],
                output_device=cfg.local_rank
            )
        self.network = network
        self.local_rank = cfg.local_rank
        self.device = device

    def reduce_loss_stats(self, loss_stats):
        # Average each loss term so every entry is a scalar tensor.
        reduced_losses = {k: torch.mean(v) for k, v in loss_stats.items()}
        return reduced_losses

    def to_cuda(self, batch):
        # Move every tensor in the batch to the training device,
        # leaving the non-tensor 'meta' entry untouched.
        for k in batch:
            if k == 'meta':
                continue
            if isinstance(batch[k], (tuple, list)):
                batch[k] = [b.to(self.device) for b in batch[k]]
            else:
                batch[k] = batch[k].to(self.device)
        return batch

    def train(self, epoch, data_loader, optimizer, recorder):
        max_iter = len(data_loader)
        self.network.train()
        end = time.time()
        for iteration, batch in enumerate(data_loader):
            data_time = time.time() - end
            iteration = iteration + 1
            batch = self.to_cuda(batch)
            output, loss, loss_stats, image_stats = self.network(batch)

            # training stage: loss; optimizer; scheduler
            optimizer.zero_grad()
            loss = loss.mean()
            loss.backward()
            # Clip gradients element-wise to stabilize training.
            torch.nn.utils.clip_grad_value_(self.network.parameters(), 40)
            optimizer.step()

            # Only the master process records and prints statistics.
            if cfg.local_rank > 0:
                continue

            # data recording stage: loss_stats, time, image_stats
            recorder.step += 1
            loss_stats = self.reduce_loss_stats(loss_stats)
            recorder.update_loss_stats(loss_stats)

            batch_time = time.time() - end
            end = time.time()
            recorder.batch_time.update(batch_time)
            recorder.data_time.update(data_time)

            if iteration % cfg.log_interval == 0 or iteration == (max_iter - 1):
                # print training state: ETA, losses, learning rate, GPU memory
                eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                lr = optimizer.param_groups[0]['lr']
                memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0

                training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}'])
                training_state = training_state.format(eta_string, str(recorder), lr, memory)
                print(training_state)

            if iteration % cfg.record_interval == 0 or iteration == (max_iter - 1):
                # record loss_stats and image_dict
                recorder.update_image_stats(image_stats)
                recorder.record('train')

    def val(self, epoch, data_loader, evaluator=None, recorder=None):
        self.network.eval()
        torch.cuda.empty_cache()
        val_loss_stats = {}
        data_size = len(data_loader)
        for batch in tqdm.tqdm(data_loader):
            batch = self.to_cuda(batch)
            with torch.no_grad():
                output, loss, loss_stats, image_stats = self.network(batch)
                if evaluator is not None:
                    evaluator.evaluate(output, batch)

            # Accumulate per-batch losses so they can be averaged below.
            loss_stats = self.reduce_loss_stats(loss_stats)
            for k, v in loss_stats.items():
                val_loss_stats.setdefault(k, 0)
                val_loss_stats[k] += v

        # Average the accumulated losses over the whole validation set.
        loss_state = []
        for k in val_loss_stats.keys():
            val_loss_stats[k] /= data_size
            loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k]))
        print(' '.join(loss_state))

        if evaluator is not None:
            result = evaluator.summarize()
            val_loss_stats.update(result)

        if recorder:
            # image_stats holds the statistics from the last batch.
            recorder.record('val', epoch, val_loss_stats, image_stats)
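

if __name__ == '__main__':
    # Usage sketch, not part of the original file: one plausible way to
    # drive this Trainer. make_network, make_data_loader, make_optimizer,
    # make_recorder, and cfg.train.epoch are assumed project factories and
    # config fields (hypothetical names, for illustration only).
    from lib.networks import make_network
    from lib.datasets import make_data_loader
    from lib.train import make_optimizer, make_recorder

    network = make_network(cfg)
    trainer = Trainer(network)
    optimizer = make_optimizer(cfg, network)
    recorder = make_recorder(cfg)
    train_loader = make_data_loader(cfg, is_train=True)
    val_loader = make_data_loader(cfg, is_train=False)

    for epoch in range(cfg.train.epoch):
        trainer.train(epoch, train_loader, optimizer, recorder)
        trainer.val(epoch, val_loader, recorder=recorder)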