gen6scp's picture
Patched codes for ZeroGPU
d643072
raw
history blame
3.05 kB
import os
import time
from copy import deepcopy
from typing import Optional
import torch.backends.cudnn
import torch.distributed
import torch.nn as nn
from ..apps.utils import (
dist_init,
dump_config,
get_dist_local_rank,
get_dist_rank,
get_dist_size,
init_modules,
is_master,
load_config,
partial_update_config,
zero_last_gamma,
)
from ..models.utils import build_kwargs_from_config, load_state_dict_from_file
# Public API of this module: the names exported by `from <module> import *`.
# Each entry corresponds to a function defined below in this file.
__all__ = [
"save_exp_config",
"setup_dist_env",
"setup_seed",
"setup_exp_config",
"init_model",
]
def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump the experiment config to ``<path>/<name>``.

    The file is written only when ``is_master()`` is true; on every other
    process the call is a no-op.

    Args:
        exp_config: Configuration dictionary handed to ``dump_config``.
        path: Directory that receives the file.
        name: File name inside ``path``.
    """
    if is_master():
        target_file = os.path.join(path, name)
        dump_config(exp_config, target_file)
def setup_dist_env(gpu: Optional[str] = None) -> None:
    """Prepare the process for distributed CUDA execution.

    Steps, in order:
      1. If ``gpu`` is given, export it as ``CUDA_VISIBLE_DEVICES``.
      2. Call ``dist_init()`` unless ``torch.distributed`` is already
         initialized.
      3. Turn on ``torch.backends.cudnn.benchmark``.
      4. Select the CUDA device whose index is ``get_dist_local_rank()``.

    Args:
        gpu: Value for ``CUDA_VISIBLE_DEVICES``; ``None`` leaves the
            environment untouched.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    group_ready = torch.distributed.is_initialized()
    if not group_ready:
        dist_init()

    torch.backends.cudnn.benchmark = True

    local_device = get_dist_local_rank()
    torch.cuda.set_device(local_device)
def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed the torch CPU and CUDA random number generators.

    The effective seed is ``get_dist_rank()`` plus a base value, so each
    rank gets a different seed. The base is ``manual_seed``, except when
    ``resume`` is true: then the current wall-clock time (whole seconds)
    is used and ``manual_seed`` is ignored.

    Args:
        manual_seed: Base seed used when not resuming.
        resume: When true, derive the base seed from ``time.time()``.
    """
    base_seed = int(time.time()) if resume else manual_seed
    rank_seed = get_dist_rank() + base_seed
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed_all(rank_seed)
def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
    """Build the experiment config from a config file and its ancestor defaults.

    When ``recursive`` is true, every ancestor directory of ``config_path``
    is checked for a file named ``default<ext>`` (same extension as
    ``config_path``). The files found are applied from the outermost
    directory inward, with ``config_path`` itself applied last, each one
    merged into the result through ``partial_update_config``. ``opt_args``,
    if given, is merged in at the very end.

    Args:
        config_path: Path to the experiment config file.
        recursive: Whether to look for ``default<ext>`` files in ancestor
            directories.
        opt_args: Optional overrides merged after all files.

    Returns:
        The merged configuration dictionary.

    Raises:
        ValueError: If ``config_path`` is not an existing file.
    """
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    # Collected innermost-first: the config itself, then defaults found
    # while climbing towards the filesystem root.
    config_files = [config_path]
    if recursive:
        _, extension = os.path.splitext(config_path)
        default_name = "default" + extension
        current = config_path
        while True:
            parent = os.path.dirname(current)
            if parent == current:
                # dirname() reached a fixed point: nothing left to climb.
                break
            current = parent
            candidate = os.path.join(current, default_name)
            if os.path.isfile(candidate):
                config_files.append(candidate)
        # Outermost default first so that more specific files are merged later.
        config_files.reverse()

    base_file, *override_files = config_files
    exp_config = deepcopy(load_config(base_file))
    for override_file in override_files:
        partial_update_config(exp_config, load_config(override_file))

    # Explicit arguments are merged last.
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)
    return exp_config
def init_model(
    network: nn.Module,
    init_from: Optional[str] = None,
    backbone_init_from: Optional[str] = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Initialize ``network`` in place, optionally loading weights from disk.

    The network is always passed through ``init_modules`` first and, when
    ``last_gamma`` is not ``None``, through ``zero_last_gamma``. After that
    exactly one of the following happens, in priority order:

      * ``init_from`` names an existing file: its state dict is loaded
        into the whole network.
      * ``backbone_init_from`` names an existing file: its state dict is
        loaded into ``network.backbone``.
      * otherwise the random initialization is kept.

    A path that is given but does not exist is skipped without an error.
    A one-line message describing the outcome is printed in every case.

    Args:
        network: Module to initialize.
        init_from: Optional checkpoint path for the full network.
        backbone_init_from: Optional checkpoint path for the backbone only.
        rand_init: Forwarded to ``init_modules`` as ``init_type``.
        last_gamma: Forwarded to ``zero_last_gamma`` when not ``None``.
    """

    def _is_usable(checkpoint: Optional[str]) -> bool:
        # A checkpoint counts only if a path was given and the file exists.
        return checkpoint is not None and os.path.isfile(checkpoint)

    # Random initialization always runs; loaded weights overwrite it afterwards.
    init_modules(network, init_type=rand_init)
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    if _is_usable(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
        return

    if _is_usable(backbone_init_from):
        backbone = network.backbone
        backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
        return

    print(f"Random init ({rand_init}) with last gamma {last_gamma}")