import torch as t
import torch.distributed as dist
from tqdm import tqdm
from datetime import date
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))
from utils.torch_utils import parse_args

args = parse_args()
mydevice = t.device(f'cuda:{args.gpu}')

def def_tqdm(x):
    return tqdm(x, leave=True, file=sys.stdout,
                bar_format="{n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]")

def get_range(x):
    if dist.get_rank() == 0:
        return def_tqdm(x)
    else:
        return x
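# Usage sketch (hypothetical `loader` and `train_step`): the wrapper is
# transparent, so the same loop body runs on every rank while only rank 0
# draws a progress bar:
#
#     for batch in get_range(loader):
#         train_step(batch)
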
def init_logging(hps, local_rank, rank):
    logdir = f"{hps.local_logdir}/{hps.name}"
    if local_rank == 0:
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        # Record the command line for this run alongside its logs.
        with open(os.path.join(logdir, 'argv.txt'), 'w') as f:
            f.write(hps.argv + '\n')
        print("Logging to", logdir)
    logger = Logger(logdir, rank)
    metrics = Metrics()
    logger.add_text('hps', str(hps))
    return logger, metrics
def get_name(hps):
    name = ""
    for key, value in hps.items():
        name += f"{key}_{value}_"
    return name
def average_metrics(_metrics):
    metrics = {}
    for _metric in _metrics:
        for key, val in _metric.items():
            if key not in metrics:
                metrics[key] = []
            metrics[key].append(val)
    # Mean of each metric across all collected dicts.
    return {key: sum(vals) / len(vals) for key, vals in metrics.items()}
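# Sketch with hypothetical inputs: average_metrics([{'loss': 0.5}, {'loss': 0.7}])
# returns {'loss': 0.6}. A key missing from some dicts is averaged over only
# the dicts that contain it.
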
class Metrics:
    def __init__(self):
        self.sum = {}
        self.n = {}

    def update(self, tag, val, batch):
        # `val` is the average value over the local batch. Accumulate the
        # global total and count across ranks, and return the distributed
        # average for this update.
        total = t.tensor(val * batch).float().to(mydevice)
        n = t.tensor(batch).float().to(mydevice)
        dist.all_reduce(total)
        dist.all_reduce(n)
        total = total.item()
        n = n.item()
        self.sum[tag] = self.sum.get(tag, 0.0) + total
        self.n[tag] = self.n.get(tag, 0.0) + n
        return total / n

    def avg(self, tag):
        if tag in self.sum:
            return self.sum[tag] / self.n[tag]
        else:
            return 0.0

    def reset(self):
        self.sum = {}
        self.n = {}
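# Usage sketch (hypothetical names): each rank passes its local batch-average
# loss and batch size; update() all-reduces and returns the global average for
# this step, while avg() gives the running average since the last reset():
#
#     step_loss = metrics.update('loss', loss.item(), x.shape[0])
#     running_loss = metrics.avg('loss')
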
class Logger:
    def __init__(self, logdir, rank):
        if rank == 0:
            from tensorboardX import SummaryWriter
            self.sw = SummaryWriter(f"{logdir}/logs")
        self.iters = 0
        self.rank = rank
        self.works = []
        self.logdir = logdir

    def step(self):
        self.iters += 1

    def flush(self):
        if self.rank == 0:
            self.sw.flush()

    def add_text(self, tag, text):
        if self.rank == 0:
            self.sw.add_text(tag, text, self.iters)

    def add_audios(self, tag, auds, sample_rate=22050, max_len=None, max_log=8):
        if self.rank == 0:
            for i in range(min(len(auds), max_log)):
                if max_len:
                    self.sw.add_audio(f"{i}/{tag}", auds[i][:max_len * sample_rate], self.iters, sample_rate)
                else:
                    self.sw.add_audio(f"{i}/{tag}", auds[i], self.iters, sample_rate)

    def add_audio(self, tag, aud, sample_rate=22050):
        if self.rank == 0:
            self.sw.add_audio(tag, aud, self.iters, sample_rate)

    def add_images(self, tag, img, dataformats="NHWC"):
        if self.rank == 0:
            self.sw.add_images(tag, img, self.iters, dataformats=dataformats)

    def add_image(self, tag, img):
        if self.rank == 0:
            self.sw.add_image(tag, img, self.iters)

    def add_scalar(self, tag, val):
        if self.rank == 0:
            self.sw.add_scalar(tag, val, self.iters)

    def get_range(self, loader):
        if self.rank == 0:
            self.trange = def_tqdm(loader)
        else:
            self.trange = loader
        return enumerate(self.trange)

    def close_range(self):
        if self.rank == 0:
            self.trange.close()

    def set_postfix(self, *args, **kwargs):
        if self.rank == 0:
            self.trange.set_postfix(*args, **kwargs)
    # For logging summaries of various graph ops.
    def add_reduce_scalar(self, tag, layer, val):
        if self.iters % 100 == 0:
            with t.no_grad():
                val = val.float().norm() / float(val.numel())
            work = dist.reduce(val, 0, async_op=True)
            self.works.append((tag, layer, val, work))

    def finish_reduce(self):
        for tag, layer, val, work in self.works:
            work.wait()
            if self.rank == 0:
                val = val.item() / dist.get_world_size()
                # Namespace the scalar by layer on the main writer (the
                # original referenced an undefined `self.lw`).
                self.sw.add_scalar(f"{layer}/{tag}", val, self.iters)
        self.works = []
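
# Minimal end-to-end sketch, assuming the process was launched with torchrun
# (so RANK/LOCAL_RANK/WORLD_SIZE are set) and a CUDA device is available.
# `_HParams`, its fields, and the metric names are illustrative only, not
# part of this module's API.
if __name__ == '__main__':
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))

    class _HParams(dict):
        # Tiny stand-in hparams object: a dict with attribute access.
        __getattr__ = dict.__getitem__

    hps = _HParams(local_logdir='logs', name='demo', argv=' '.join(sys.argv))
    logger, metrics = init_logging(hps, local_rank, rank)

    # One fake training step: log the cross-rank average of a local value.
    step_avg = metrics.update('loss', val=1.0, batch=8)
    logger.add_scalar('train/loss', step_avg)
    logger.step()
    logger.flush()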