# AverageMeter adapted from robustness.robustness.tools.helpers
# (https://github.com/MadryLab/robustness)


class AverageMeter(object):
    """Compute and store the most recent value and a weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recently passed value
        self.avg = 0    # sum / count (0 while nothing has been counted)
        self.sum = 0    # running sum of val * n
        self.count = 0  # running total of weights n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value. Any object supporting ``* n`` and ``+`` works
                (presumably a float or scalar loss -- confirm with callers).
            n: Weight of ``val``, e.g. the number of samples it summarizes.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh
        # meter), which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0


class VariableLossLogPrinter():
    """Track running averages for an open-ended set of named losses."""

    def __init__(self):
        # Maps loss name -> AverageMeter; entries are created on first use.
        self.losses = {}

    def log_loss(self, key, val, n=1):
        """Add ``val`` (with weight ``n``) to the meter for ``key``.

        A new AverageMeter is created the first time ``key`` is seen.
        """
        if key not in self.losses:
            self.losses[key] = AverageMeter()
        self.losses[key].update(val, n)

    def get_loss_string(self):
        """Return ``"name: avg | name: avg ..."`` with averages to 4 decimals.

        Keys appear in first-logged order (dict insertion order). Returns an
        empty string when nothing has been logged.
        """
        loss_string = " | ".join(
            [f"{key}: {meter.avg:.4f}" for key, meter in self.losses.items()]
        )
        return loss_string