Spaces:
Configuration error
Configuration error
import time | |
import datetime | |
import torch | |
import tqdm | |
from torch.nn import DataParallel | |
from lib.config import cfg | |
class Trainer(object):
    """Run per-epoch training and validation of a network on one CUDA device.

    The wrapped network is invoked as ``network(batch)`` and is expected to
    return the 4-tuple ``(output, loss, loss_stats, image_stats)``.
    """

    def __init__(self, network):
        """Place ``network`` on this process's GPU, wrapping it for DDP if enabled.

        Args:
            network: module to train; moved to ``cuda:<cfg.local_rank>``. When
                ``cfg.distributed`` is truthy it is wrapped in
                ``DistributedDataParallel`` bound to that same device.
        """
        device = torch.device('cuda:{}'.format(cfg.local_rank))
        network = network.to(device)
        if cfg.distributed:
            network = torch.nn.parallel.DistributedDataParallel(
                network,
                device_ids=[cfg.local_rank],
                output_device=cfg.local_rank
            )
        self.network = network
        self.local_rank = cfg.local_rank
        self.device = device

    @staticmethod
    def _hits_interval_or_end(iteration, interval, max_iter):
        """Return True on every ``interval``-th iteration and on the final one.

        ``iteration`` is 1-based, so the final iteration equals ``max_iter``.
        """
        return iteration % interval == 0 or iteration == max_iter

    def reduce_loss_stats(self, loss_stats):
        """Return a new dict mapping each key to ``torch.mean`` of its value."""
        return {name: torch.mean(value) for name, value in loss_stats.items()}

    def to_cuda(self, batch):
        """Move every entry of ``batch`` except ``'meta'`` to ``self.device``.

        Tuple and list entries are replaced by a list of moved elements; any
        other entry is moved with its own ``.to``. ``batch`` is modified in
        place and also returned.
        """
        for key in batch:
            if key == 'meta':
                continue
            value = batch[key]
            if isinstance(value, (tuple, list)):
                batch[key] = [item.to(self.device) for item in value]
            else:
                batch[key] = value.to(self.device)
        return batch

    def train(self, epoch, data_loader, optimizer, recorder):
        """Run one optimisation pass over ``data_loader``.

        Args:
            epoch: current epoch index (not used inside this method).
            data_loader: sized iterable of batches accepted by ``to_cuda``.
            optimizer: optimizer stepped once per batch.
            recorder: object receiving loss, timing and image statistics; only
                touched when ``cfg.local_rank`` is 0.
        """
        max_iter = len(data_loader)
        self.network.train()
        end = time.time()
        for iteration, batch in enumerate(data_loader, start=1):
            data_time = time.time() - end

            batch = self.to_cuda(batch)
            output, loss, loss_stats, image_stats = self.network(batch)

            # training stage: loss; optimizer; scheduler
            optimizer.zero_grad()
            loss = loss.mean()
            loss.backward()
            torch.nn.utils.clip_grad_value_(self.network.parameters(), 40)
            optimizer.step()

            # Only rank 0 records and prints; other ranks just keep training.
            if cfg.local_rank > 0:
                continue

            # data recording stage: loss_stats, time, image_stats
            recorder.step += 1

            loss_stats = self.reduce_loss_stats(loss_stats)
            recorder.update_loss_stats(loss_stats)

            batch_time = time.time() - end
            end = time.time()
            recorder.batch_time.update(batch_time)
            recorder.data_time.update(data_time)

            # iteration is 1-based, so the last batch is max_iter (the previous
            # `max_iter - 1` check fired one batch early and never on the last).
            if self._hits_interval_or_end(iteration, cfg.log_interval, max_iter):
                # print training state
                eta_seconds = recorder.batch_time.global_avg * (max_iter - iteration)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                lr = optimizer.param_groups[0]['lr']
                memory = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0

                training_state = ' '.join(['eta: {}', '{}', 'lr: {:.6f}', 'max_mem: {:.0f}'])
                training_state = training_state.format(eta_string, str(recorder), lr, memory)
                print(training_state)

            if self._hits_interval_or_end(iteration, cfg.record_interval, max_iter):
                # record loss_stats and image_dict
                recorder.update_image_stats(image_stats)
                recorder.record('train')

    def val(self, epoch, data_loader, evaluator=None, recorder=None):
        """Evaluate the network over ``data_loader`` and report averaged losses.

        Args:
            epoch: epoch index forwarded to ``recorder.record``.
            data_loader: sized iterable of validation batches.
            evaluator: optional object whose ``evaluate(output, batch)`` is
                called per batch and whose ``summarize()`` result is merged
                into the loss statistics.
            recorder: optional object whose ``record`` receives the averaged
                statistics and the last batch's ``image_stats``.
        """
        self.network.eval()
        torch.cuda.empty_cache()
        val_loss_stats = {}
        data_size = len(data_loader)
        # Stays None when the loader yields no batches, instead of being
        # unbound at the recorder call below.
        # NOTE(review): confirm the recorder accepts image_stats=None.
        image_stats = None
        for batch in tqdm.tqdm(data_loader):
            batch = self.to_cuda(batch)
            with torch.no_grad():
                output, loss, loss_stats, image_stats = self.network(batch)
                if evaluator is not None:
                    evaluator.evaluate(output, batch)

            loss_stats = self.reduce_loss_stats(loss_stats)
            for k, v in loss_stats.items():
                val_loss_stats.setdefault(k, 0)
                val_loss_stats[k] += v

        loss_state = []
        for k in val_loss_stats:
            # Average the per-batch sums over the number of batches.
            val_loss_stats[k] /= data_size
            loss_state.append('{}: {:.4f}'.format(k, val_loss_stats[k]))
        print(loss_state)

        if evaluator is not None:
            result = evaluator.summarize()
            val_loss_stats.update(result)

        if recorder:
            recorder.record('val', epoch, val_loss_stats, image_stats)