# Scraped from a Hugging Face Space (status: Running on Zero).
# File size: 1,466 bytes; revision 32287b3.
import wandb
import torch
from torchvision.utils import make_grid
import torch.distributed as dist
from PIL import Image
import os
import argparse
import hashlib
import math
def is_main_process():
    """Return True if this process should perform logging (rank 0).

    In a distributed run this is rank 0 of the default process group. When
    torch.distributed is not available in this build, or no process group
    has been initialized (e.g. a plain single-process run), the current
    process is treated as the main process instead of raising.
    """
    # dist.get_rank() raises if the default process group was never
    # initialized, which would crash every log call in single-process runs.
    if not dist.is_available() or not dist.is_initialized():
        return True
    return dist.get_rank() == 0
def namespace_to_dict(namespace):
    """Convert an argparse.Namespace into a plain dict.

    Nested Namespace values are converted recursively. Every other value is
    copied over as-is; lists, dicts, etc. are not traversed. Key order follows
    ``vars(namespace)``.
    """
    result = {}
    for key, value in vars(namespace).items():
        if isinstance(value, argparse.Namespace):
            value = namespace_to_dict(value)
        result[key] = value
    return result
def generate_run_id(exp_name):
    """Derive a deterministic numeric run id from an experiment name.

    The SHA-256 digest of the UTF-8 encoded name is read as an integer and
    reduced modulo 10**8. The same name therefore always yields the same id.
    The id is returned as a decimal string of at most 8 digits, with no zero
    padding.
    """
    # https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
    encoded = exp_name.encode('utf-8')
    digest_value = int(hashlib.sha256(encoded).hexdigest(), base=16)
    return f"{digest_value % 10 ** 8}"
def initialize(args, entity, exp_name, project_name):
    """Log in to Weights & Biases and start a run for this experiment.

    Args:
        args: argparse.Namespace of run settings. It is converted to a nested
            dict and stored as the run config.
        entity: W&B entity forwarded to ``wandb.init``.
        exp_name: run display name. It is also hashed into the run id, so the
            same name always maps to the same id. With ``resume="allow"``,
            W&B can reattach to an existing run with that id.
        project_name: W&B project forwarded to ``wandb.init``.

    Raises:
        KeyError: if the ``WANDB_KEY`` environment variable is not set.
    """
    run_config = namespace_to_dict(args)
    wandb.login(key=os.environ["WANDB_KEY"])
    init_kwargs = dict(
        entity=entity,
        project=project_name,
        name=exp_name,
        config=run_config,
        id=generate_run_id(exp_name),
        resume="allow",
    )
    wandb.init(**init_kwargs)
def log(stats, step=None):
    """Send a dict of stats to W&B, but only from the main process.

    A shallow copy of ``stats`` is passed to ``wandb.log``. ``step`` is
    forwarded as W&B's ``step`` argument; when it is None, step selection is
    left to wandb.
    """
    if not is_main_process():
        return
    wandb.log(dict(stats.items()), step=step)
def log_image(name, sample, step=None):
    """Tile a batch of images into one grid and log it to W&B (main process only).

    The grid image is logged under the key ``name``. ``step`` is NOT passed as
    W&B's ``step`` argument. Instead it is logged alongside the image as an
    ordinary ``"train_step"`` value, which is None when omitted.
    """
    if not is_main_process():
        return
    grid = array2grid(sample)
    payload = {f"{name}": wandb.Image(grid), "train_step": step}
    wandb.log(payload)
def array2grid(x):
    """Tile a batch of image tensors into one uint8 image laid out channels-last.

    Args:
        x: batch tensor whose first dimension is the number of images.
            Presumably (N, C, H, W) with values in [-1, 1], since the
            ``value_range`` below assumes that range. TODO: confirm with callers.

    Returns:
        numpy.ndarray of dtype uint8 with shape (height, width, channels),
        on the CPU. This is the form ``log_image`` hands to ``wandb.Image``.
    """
    # Aim for a roughly square grid: about sqrt(N) images per row.
    nrow = round(math.sqrt(x.size(0)))
    # normalize=True with value_range=(-1, 1) rescales the grid into [0, 1].
    grid = make_grid(x, nrow=nrow, normalize=True, value_range=(-1, 1))
    # Scale to [0, 255] and add 0.5 so the uint8 cast rounds instead of
    # truncating. Then clamp, move channels last (CHW -> HWC) and copy to host.
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
    return grid.to('cpu', torch.uint8).numpy()