|
|
|
|
|
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes:
        val: the most recently recorded value.
        sum: weighted sum of all recorded values (each ``val`` times its ``n``).
        count: total weight recorded so far (sum of all ``n``).
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the value to record (e.g. a per-batch mean).
            n: how many samples ``val`` represents; it is weighted by this.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an empty meter (e.g. update(val, n=0) right after reset),
        # which would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
|
|
|
class VariableLossLogPrinter:
    """Track running averages for an open-ended set of named losses.

    A new ``AverageMeter`` is created the first time a key is logged, so
    callers do not need to declare the loss names up front.
    """

    def __init__(self):
        # Maps loss name -> AverageMeter. Dict insertion order (guaranteed in
        # Python 3.7+) decides the order of entries in get_loss_string().
        self.losses = {}

    def log_loss(self, key, val, n=1):
        """Record ``val`` (weighted by ``n``) under ``key``.

        Args:
            key: name of the loss; a meter is created on first use.
            val: value to record.
            n: weight passed through to ``AverageMeter.update``.
        """
        if key not in self.losses:
            self.losses[key] = AverageMeter()
        self.losses[key].update(val, n)

    def get_loss_string(self, precision=4):
        """Return the averages as ``"name: avg | name: avg | ..."``.

        Args:
            precision: number of decimal places for each average
                (defaults to 4, the original fixed format).

        Returns:
            The formatted string, or ``""`` if nothing has been logged.
        """
        return " | ".join(
            f"{key}: {meter.avg:.{precision}f}" for key, meter in self.losses.items()
        )
|
|