# NOTE: the following lines are web-page residue from the file listing,
# kept as a comment so the module parses:
#   robinwitch's picture / update / 1da48bb
import os
import sys
from datetime import date

import torch as t
import torch.distributed as dist
from tqdm import tqdm

# Make the repository root importable so `utils` resolves regardless of the
# directory this script is launched from.
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from utils.torch_utils import parse_args

args = parse_args()
# Device chosen by the --gpu CLI flag, e.g. "0" -> cuda:0.
mydevice = t.device('cuda:' + args.gpu)
def def_tqdm(x):
    """Wrap *x* in a tqdm progress bar bound to stdout with a compact format."""
    fmt = "{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]"
    return tqdm(x, leave=True, file=sys.stdout, bar_format=fmt)
def get_range(x):
    """Return *x* wrapped in a progress bar on (distributed) rank 0, else as-is."""
    return def_tqdm(x) if dist.get_rank() == 0 else x
def init_logging(hps, local_rank, rank):
    """Set up the run's log directory and return a (Logger, Metrics) pair.

    Args:
        hps: hyper-parameter namespace; reads `local_logdir`, `name`, `argv`.
        local_rank: rank within this node; only local rank 0 touches disk.
        rank: global rank; passed to Logger so only rank 0 writes summaries.

    Returns:
        (logger, metrics) tuple.
    """
    logdir = f"{hps.local_logdir}/{hps.name}"
    if local_rank == 0:
        os.makedirs(logdir, exist_ok=True)
        # BUG FIX: `logdir + 'argv.txt'` has no path separator, so the file
        # landed *next to* the log directory as "<name>argv.txt". Join the
        # path properly so argv.txt is written inside the log directory.
        with open(os.path.join(logdir, 'argv.txt'), 'w') as f:
            f.write(hps.argv + '\n')
        print("Logging to", logdir)
    logger = Logger(logdir, rank)
    metrics = Metrics()
    logger.add_text('hps', str(hps))
    return logger, metrics
def get_name(hps):
    """Build a run name by joining every `key_value_` pair in *hps* in order."""
    return "".join(f"{key}_{value}_" for key, value in hps.items())
def average_metrics(_metrics):
    """Average a sequence of metric dicts key-wise.

    Args:
        _metrics: iterable of {name: numeric value} dicts; keys may differ
            between dicts — each key is averaged over the dicts containing it.

    Returns:
        dict mapping each metric name to the mean of its collected values.
    """
    collected = {}
    for metric in _metrics:
        for key, val in metric.items():
            collected.setdefault(key, []).append(val)
    # BUG FIX: the original used floor division (`//`), which truncates the
    # mean of float-valued metrics (e.g. losses); use true division.
    return {key: sum(vals) / len(vals) for key, vals in collected.items()}
class Metrics:
    """Accumulates per-tag sums and sample counts, reduced across workers."""

    def __init__(self):
        # Running totals keyed by tag: weighted value sum and sample count.
        self.sum = {}
        self.n = {}

    def update(self, tag, val, batch):
        # `val` is the mean over this worker's batch; all-reduce the weighted
        # total and the count, accumulate both, and return the mean for this
        # step across every worker.
        total = t.tensor(val * batch).float().to(mydevice)
        count = t.tensor(batch).float().to(mydevice)
        dist.all_reduce(total)
        dist.all_reduce(count)
        total, count = total.item(), count.item()
        self.sum[tag] = self.sum.get(tag, 0.0) + total
        self.n[tag] = self.n.get(tag, 0.0) + count
        return total / count

    def avg(self, tag):
        # Running average over everything accumulated so far; 0.0 for tags
        # that were never updated.
        if tag not in self.sum:
            return 0.0
        return self.sum[tag] / self.n[tag]

    def reset(self):
        # Drop all accumulated state.
        self.sum = {}
        self.n = {}
class Logger:
    """Rank-aware logging front-end: only rank 0 owns a SummaryWriter and
    writes tensorboard events; every other rank's log calls are no-ops."""

    def __init__(self, logdir, rank):
        if rank == 0:
            # Imported lazily so non-zero ranks never need tensorboardX.
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")
        self.iters = 0          # global step used for every summary
        self.rank = rank
        self.works = []         # pending async dist.reduce handles
        self.logdir = logdir

    def step(self):
        """Advance the global step counter."""
        self.iters += 1

    def flush(self):
        if self.rank == 0:
            self.sw.flush()

    def add_text(self, tag, text):
        if self.rank == 0:
            self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        """Log up to `max_log` audio clips, truncated to `max_len` seconds."""
        if self.rank == 0:
            for i in range(min(len(auds), max_log)):
                if max_len:
                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
                else:
                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank == 0:
            self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank == 0:
            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank == 0:
            self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank == 0:
            self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        """Wrap `loader` in a progress bar on rank 0 and return enumerate()."""
        if self.rank == 0:
            self.trange = def_tqdm(loader)
        else:
            self.trange = loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank == 0:
            self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank == 0:
            self.trange.set_postfix(*args, **kwargs)

    # For logging summaries of various graph ops
    def add_reduce_scalar(self, tag, layer, val):
        """Every 100 steps, queue an async reduce of |val|'s mean-norm to rank 0.
        Must be called on all ranks; results are written in finish_reduce()."""
        if self.iters % 100 == 0:
            with t.no_grad():
                val = val.float().norm() / float(val.numel())
            work = dist.reduce(val, 0, async_op=True)
            self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        """Wait on all queued reduces and write their world-averaged values."""
        for tag, layer, val, work in self.works:
            work.wait()
            if self.rank == 0:
                val = val.item() / dist.get_world_size()
                # BUG FIX: `self.lw` was never defined anywhere, so this line
                # raised AttributeError on rank 0. Write through the
                # SummaryWriter instead, namespacing by layer to match the
                # "<prefix>/<tag>" convention used in add_audios above.
                self.sw.add_scalar(f"{layer}/{tag}", val, self.iters)
        self.works = []