File size: 4,626 Bytes
1da48bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch as t
import torch.distributed as dist
from tqdm import tqdm
from datetime import date
import os
import sys

import sys
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from utils.torch_utils import parse_args


# Parse CLI args once at import time; every module importing this file shares them.
# NOTE(review): assumes utils.torch_utils.parse_args exposes a string `gpu`
# attribute — confirm against that helper.
args = parse_args()
# Device used by Metrics.update for the all-reduce buffers (one CUDA GPU per process).
mydevice = t.device('cuda:' + args.gpu)

def def_tqdm(x):
    """Wrap *x* in a tqdm progress bar using the project's standard format.

    The bar is written to stdout (not the default stderr) and is left on
    screen after completion.
    """
    bar_fmt = "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
    return tqdm(x, leave=True, file=sys.stdout, bar_format=bar_fmt)

def get_range(x):
    """Return *x* wrapped in a progress bar on global rank 0; unchanged elsewhere."""
    return def_tqdm(x) if dist.get_rank() == 0 else x

def init_logging(hps, local_rank, rank):
    """Create the run's log directory, record argv, and build (Logger, Metrics).

    Args:
        hps: hyper-parameter namespace; reads `local_logdir`, `name`, `argv`.
        local_rank: rank within this node; only local rank 0 touches the filesystem.
        rank: global rank, forwarded to Logger (only global rank 0 writes summaries).

    Returns:
        (logger, metrics) tuple.
    """
    logdir = f"{hps.local_logdir}/{hps.name}"
    if local_rank == 0:
        # exist_ok avoids the check-then-create race of the original
        # `if not os.path.exists(...): os.makedirs(...)` pair.
        os.makedirs(logdir, exist_ok=True)
        # BUG FIX: the original used `logdir + 'argv.txt'` — with no path
        # separator this wrote a sibling file named "<name>argv.txt" OUTSIDE
        # the log directory. Join properly so argv.txt lands inside logdir.
        with open(os.path.join(logdir, 'argv.txt'), 'w') as f:
            f.write(hps.argv + '\n')
        print("Logging to", logdir)
    logger = Logger(logdir, rank)
    metrics = Metrics()
    logger.add_text('hps', str(hps))
    return logger, metrics

def get_name(hps):
    """Build a run name by concatenating every hyper-parameter as "key_value_"."""
    return "".join(f"{key}_{value}_" for key, value in hps.items())

def average_metrics(_metrics):
    """Average a list of metric dicts key-wise.

    Args:
        _metrics: iterable of {tag: value} dicts; tags may differ between dicts.

    Returns:
        {tag: arithmetic mean of all values observed for that tag}.
    """
    grouped = {}
    for metric in _metrics:
        for key, val in metric.items():
            grouped.setdefault(key, []).append(val)
    # BUG FIX: the original used floor division (//), which truncated the
    # mean of float metrics (e.g. losses) to a whole number — not an average.
    return {key: sum(vals) / len(vals) for key, vals in grouped.items()}

class Metrics:
    """Accumulates distributed running totals so per-tag averages can be reported."""

    def __init__(self):
        # Per-tag running totals: sum of (value * batch) and total sample count.
        self.sum = {}
        self.n = {}

    def update(self, tag, val, batch):
        """Fold one batch's average *val* into the running totals for *tag*.

        *val* is the local per-batch mean; it is converted back to a total,
        then both the total and the count are all-reduced across ranks.
        Returns the distributed average for this batch.
        """
        total = t.tensor(val * batch).float().to(mydevice)
        count = t.tensor(batch).float().to(mydevice)
        dist.all_reduce(total)
        dist.all_reduce(count)
        total, count = total.item(), count.item()
        self.sum[tag] = self.sum.get(tag, 0.0) + total
        self.n[tag] = self.n.get(tag, 0.0) + count
        return total / count

    def avg(self, tag):
        """Running average for *tag* since the last reset; 0.0 if unseen."""
        return self.sum[tag] / self.n[tag] if tag in self.sum else 0.0

    def reset(self):
        """Drop all accumulated totals."""
        self.sum = {}
        self.n = {}

class Logger:
    """Rank-aware TensorBoard logger.

    Only global rank 0 constructs a SummaryWriter and actually writes
    summaries; on every other rank the logging methods are no-ops, so callers
    may log unconditionally from all ranks.
    """

    def __init__(self, logdir, rank):
        if rank == 0:
            # Imported lazily so non-zero ranks never need tensorboardX.
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")
        self.iters = 0    # global step attached to every summary
        self.rank = rank
        self.works = []   # pending async reduce handles: (tag, layer, val, work)
        self.logdir = logdir

    def step(self):
        """Advance the global step counter (call once per training iteration)."""
        self.iters += 1

    def flush(self):
        if self.rank == 0:
            self.sw.flush()

    def add_text(self, tag, text):
        if self.rank == 0:
            self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        """Log up to *max_log* clips; *max_len* (seconds) truncates each clip."""
        if self.rank == 0:
            for i in range(min(len(auds), max_log)):
                if max_len:
                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
                else:
                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank == 0:
            self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank == 0:
            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank == 0:
            self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank == 0:
            self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        """Enumerate *loader*, showing a tqdm bar on rank 0 only."""
        if self.rank == 0:
            self.trange = def_tqdm(loader)
        else:
            self.trange = loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank == 0:
            self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank == 0:
            self.trange.set_postfix(*args, **kwargs)

    # For logging summaries of various graph ops
    def add_reduce_scalar(self, tag, layer, val):
        """Every 100 steps, start an async reduce of norm(val)/numel to rank 0.

        The handle is queued in self.works; finish_reduce() must be called
        later (on all ranks) to complete the reduces and write the scalars.
        """
        if self.iters % 100 == 0:
            with t.no_grad():
                val = val.float().norm() / float(val.numel())
            work = dist.reduce(val, 0, async_op=True)
            self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        """Wait for all pending reduces and write the world-averaged scalars."""
        for tag, layer, val, work in self.works:
            work.wait()
            if self.rank == 0:
                val = val.item() / dist.get_world_size()
                # BUG FIX: the original wrote to `self.lw[layer]`, an attribute
                # never assigned anywhere in this class (only self.sw exists),
                # so rank 0 raised AttributeError here. Write a layer-qualified
                # tag to the one SummaryWriter instead (mirrors the "{i}/{tag}"
                # convention used by add_audios). Confirm desired tag layout.
                self.sw.add_scalar(f"{layer}/{tag}", val, self.iters)
        self.works = []