File size: 872 Bytes
8d4ee22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

# From robustness.robustness.tools.helpers (https://github.com/MadryLab/robustness)
class AverageMeter:
    """Computes and stores the average and current value.

    Tracks the most recent value (``val``), the weighted running sum
    (``sum``), the total weight (``count``) and their ratio (``avg``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recently recorded value
        self.avg = 0    # sum / count, or 0 before anything has been counted
        self.sum = 0    # sum of val * n over all updates
        self.count = 0  # sum of n over all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the value to record (e.g. a per-batch mean loss).
            n: how many samples ``val`` stands for; it is weighted by ``n``
                in the running average. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against dividing by zero when nothing has been counted yet
        # (e.g. update(x, n=0) on a fresh meter); the previous avg is kept.
        if self.count:
            self.avg = self.sum / self.count

class VariableLossLogPrinter:
    """Accumulate running averages for an open-ended set of named losses.

    A separate ``AverageMeter`` is created for each key the first time
    :meth:`log_loss` sees it. Callers can therefore log whichever losses a
    step produced without registering them up front.
    """

    def __init__(self):
        # Maps loss name -> AverageMeter. Dict insertion order determines the
        # order in which get_loss_string reports the losses.
        self.losses = {}

    def log_loss(self, key, val, n=1):
        """Record ``val`` for the loss named ``key``, weighted by ``n``.

        Args:
            key: name of the loss; a new meter is created on first use.
            val: value to record (forwarded to ``AverageMeter.update``).
            n: weight of ``val`` in the running average. Defaults to 1.
        """
        if key not in self.losses:
            self.losses[key] = AverageMeter()
        self.losses[key].update(val, n)

    def get_loss_string(self):
        """Return every running average as ``"name: avg | name: avg ..."``.

        Averages are formatted with four decimal places. The result is an
        empty string if no loss has been logged yet.
        """
        return " | ".join(
            [f"{key}: {meter.avg:.4f}" for key, meter in self.losses.items()]
        )