# NOTE: web-UI scrape residue removed (uploader: cheng-hust, "Upload 91 files",
# commit e8861c0 verified, 4.08 kB) — kept here as a comment so the file parses.
'''
by lyuwenyu
'''
import time
import json
import datetime
import torch
from src.misc import dist
from src.data import get_coco_api_from_dataset
from .solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
class DetSolver(BaseSolver):
    """Detection training/evaluation solver built on :class:`BaseSolver`.

    Drives the epoch loop (:meth:`fit`) and a one-off evaluation pass
    (:meth:`val`) using the model, criterion, dataloaders, optimizer, EMA
    and AMP scaler prepared by the base class.
    """

    def fit(self, ):
        """Run the full training loop over ``cfg.epoches`` epochs.

        Per epoch: train, step the LR scheduler, save checkpoints, evaluate
        on the validation set, track the best value of each metric, and
        append a JSON line of stats to ``output_dir/log.txt`` on the master
        process.
        """
        print("Start training")
        self.train()
        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)

        # COCO ground-truth API for the validation set, built once up front.
        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)

        # best_stat = {'coco_eval_bbox': 0, 'coco_eval_masks': 0, 'epoch': -1, }
        best_stat = {'epoch': -1, }

        start_time = time.time()
        for epoch in range(self.last_epoch + 1, args.epoches):
            # Re-seed the distributed sampler so shuffling differs per epoch.
            if dist.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
                args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)

            self.lr_scheduler.step()

            if self.output_dir:
                checkpoint_paths = [self.output_dir / 'checkpoint.pth']
                # extra checkpoint every `checkpoint_step` epochs
                if (epoch + 1) % args.checkpoint_step == 0:
                    checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
                for checkpoint_path in checkpoint_paths:
                    dist.save_on_master(self.state_dict(epoch), checkpoint_path)

            # Evaluate with the EMA weights when an EMA model is maintained.
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
            )

            # Track the best value seen so far for every reported metric.
            # test_stats values are sequences; index 0 is the headline number.
            for k in test_stats.keys():
                if k in best_stat:
                    best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat['epoch'] = epoch
                    best_stat[k] = test_stats[k][0]
            print('best_stat: ', best_stat)

            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

            # Only the master process writes logs / eval dumps to disk.
            if self.output_dir and dist.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs: persist raw COCO eval state for
                # offline analysis ('latest' every epoch, snapshot every 50).
                if coco_evaluator is not None:
                    (self.output_dir / 'eval').mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ['latest.pth']
                        if epoch % 50 == 0:
                            filenames.append(f'{epoch:03}.pth')
                        for name in filenames:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval,
                                       self.output_dir / "eval" / name)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))

    def val(self, ):
        """Evaluate the (EMA) model once on the validation set and optionally
        dump the raw COCO bbox eval state to ``output_dir/eval.pth``."""
        self.eval()

        base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)

        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
                                              self.val_dataloader, base_ds, self.device, self.output_dir)

        # Fix: guard against a missing evaluator (evaluate may return None),
        # mirroring the `coco_evaluator is not None` check done in fit();
        # previously this line raised AttributeError when no evaluator existed.
        if self.output_dir and coco_evaluator is not None:
            dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")
        return