import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
import torch
from tap import Tap
except ImportError as e:
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
raise e
import dist
# Args: typed-argument-parser (Tap) schema holding the command-line options and the runtime
# bookkeeping fields of a VAR training run.
# Fields tagged "[automatically set; don't specify this]" are meant to be filled in by the
# program (several of them in init_dist_and_get_args), not passed on the command line.
class Args(Tap):
data_path: str = '/path/to/imagenet' # placeholder default; init_dist_and_get_args checks for it and raises ValueError
exp_name: str = 'text'
vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
# NOTE: compile_model additionally maps the value 3 to mode 'default'.
depth: int = 16 # VAR depth
# VAR initialization
ini: float = -1 # -1: automated model parameter initialization
hd: float = 0.02 # head.w *= hd
aln: float = 0.5 # the multiplier of ada_lin.w's initialization
alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
# VAR optimization
fp16: int = 0 # 1: using fp16, 2: bf16
tblr: float = 1e-4 # base lr
tlr: float = None # lr = base lr * (bs / 256)
twd: float = 0.05 # initial wd
twde: float = 0 # final wd, =twde or twd
tclip: float = 2. # <=0 for not using grad clip
ls: float = 0.0 # label smooth
bs: int = 768 # global batch size
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
ac: int = 1 # gradient accumulation
ep: int = 250 # presumably the number of training epochs — TODO confirm
wp: float = 0 # presumably lr warm-up length in epochs; 0 is replaced by ep * 1/50 in init_dist_and_get_args
wp0: float = 0.005 # initial lr ratio at the beginning of lr warm up
wpe: float = 0.01 # final lr ratio at the end of training
sche: str = 'lin0' # lr schedule; overwritten with f'lin{pg:g}' in init_dist_and_get_args when pg > 0
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
afuse: bool = True # fused adamw
# other hps
saln: bool = False # whether to use shared adaln
anorm: bool = True # whether to use L2 normalized attention
fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
# data
pn: str = '1_2_3_4_5_6_8_10_13_16' # the shortcuts '256', '512', '1024' are expanded to preset lists in init_dist_and_get_args
patch_size: int = 16
patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
hflip: bool = False # augmentation: horizontal flip
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
# progressive training
pg: float = 0.0 # >0 for use progressive training during [0%, this] of training
pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
pgwp: float = 0 # num of warmup epochs at each progressive stage; 0 is replaced by ep * 1/300 in init_dist_and_get_args
# would be automatically set in runtime
cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
# NOTE(review): the three git defaults below invoke `git` through the shell while the class body
# is being evaluated. subprocess.check_output raises CalledProcessError when the command exits
# non-zero, so the `or '[unknown]'` fallbacks only cover the case of empty output.
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
# last line of `git log -1` output, stripped
commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
# metric / progress fields below are the ones dump_log collects into its log dict
acc_mean: float = None # [automatically set; don't specify this]
acc_tail: float = None # [automatically set; don't specify this]
L_mean: float = None # [automatically set; don't specify this]
L_tail: float = None # [automatically set; don't specify this]
vacc_mean: float = None # [automatically set; don't specify this]
vacc_tail: float = None # [automatically set; don't specify this]
vL_mean: float = None # [automatically set; don't specify this]
vL_tail: float = None # [automatically set; don't specify this]
grad_norm: float = None # [automatically set; don't specify this]
cur_lr: float = None # [automatically set; don't specify this]
cur_wd: float = None # [automatically set; don't specify this]
cur_it: str = '' # [automatically set; don't specify this]
cur_ep: str = '' # [automatically set; don't specify this]
remain_time: str = '' # [automatically set; don't specify this]
finish_time: str = '' # [automatically set; don't specify this]
# environment
# default output directory: 'local_output' under the parent of this file's directory
local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
log_txt_path: str = '...' # [automatically set; don't specify this]
last_ckpt_path: str = '...' # [automatically set; don't specify this]
tf32: bool = True # whether to use TensorFloat32
device: str = 'cpu' # [automatically set; don't specify this]
seed: int = None # seed; None leaves cuDNN non-deterministic (see seed_everything)
# Configure the cuDNN backend flags and, when self.seed is set, derive a per-rank seed.
#   benchmark: value assigned to torch.backends.cudnn.benchmark.
# NOTE(review): this method looks incomplete in this view (see the notes below); confirm against
# the full file before relying on it.
def seed_everything(self, benchmark: bool):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = benchmark
if self.seed is None:
torch.backends.cudnn.deterministic = False
# NOTE(review): an `else:` is presumably missing here — the next lines multiply self.seed,
# which would fail for None, and would otherwise overwrite the flag set just above.
torch.backends.cudnn.deterministic = True
# combine the user seed with world size and rank so that each rank gets its own value
seed = self.seed * dist.get_world_size() + dist.get_rank()
os.environ['PYTHONHASHSEED'] = str(seed)
# NOTE(review): no statement that applies `seed` to random / numpy / torch is visible here,
# and the `if` below has no visible body — presumably the seeding calls are missing.
if torch.cuda.is_available():
# class-level field (not part of seed_everything)
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
    """Create a ``torch.Generator`` whose seed differs from rank to rank.

    Returns:
        ``None`` when ``self.seed`` is ``None``; otherwise a new generator
        seeded with ``self.seed * dist.get_world_size() + dist.get_rank()``.
    """
    if self.seed is not None:
        generator = torch.Generator()
        rank_seed = self.seed * dist.get_world_size() + dist.get_rank()
        generator.manual_seed(rank_seed)
        return generator
    return None
local_debug: bool = 'KEVIN_LOCAL' in os.environ
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ

def compile_model(self, m, fast):
    """Optionally wrap ``m`` with ``torch.compile``.

    Args:
        m: the object to compile.
        fast: 0 leaves ``m`` untouched; 1, 2 and 3 select the modes
            'reduce-overhead', 'max-autotune' and 'default' respectively.

    Returns:
        ``m`` itself when ``fast == 0``, when ``self.local_debug`` is truthy,
        or when this torch build has no ``torch.compile``; otherwise the
        result of ``torch.compile``.

    Raises:
        KeyError: if compilation is attempted with an unlisted ``fast`` value.
    """
    skip_compile = fast == 0 or self.local_debug
    if skip_compile or not hasattr(torch, 'compile'):
        return m
    mode_by_level = {
        1: 'reduce-overhead',
        2: 'max-autotune',
        3: 'default',
    }
    return torch.compile(m, mode=mode_by_level[fast])
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
    """Snapshot the declared class variables as a mapping of name to current value.

    Args:
        key_ordered: build an ``OrderedDict`` when truthy, a plain ``dict`` otherwise.

    Returns:
        One entry per key of ``self.class_variables``, except ``'device'``.
    """
    skipped = {'device'}  # these are not serializable
    container = OrderedDict if key_ordered else dict
    # self.as_dict() would contain methods, but we only need variables
    return container(
        (name, getattr(self, name))
        for name in self.class_variables
        if name not in skipped
    )
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
    """Copy every entry of ``d`` onto this object as an attribute.

    Args:
        d: a mapping of attribute name to value, or (old format) a string
            containing a dict literal spread over several lines.

    Raises:
        Exception: whatever ``setattr`` raised for an entry; the offending
            key and value are printed before the exception propagates.
    """
    if isinstance(d, str):  # for compatibility with old version
        # Lines containing '<bound' or 'device(' are dropped before evaluation
        # (presumably reprs of bound methods / devices that cannot be evaluated back).
        kept = [line for line in d.splitlines() if '<bound' not in line and 'device(' not in line]
        # NOTE(security): eval executes arbitrary code — only pass strings from trusted sources.
        d = eval('\n'.join(kept))
    for k, v in d.items():
        try:
            setattr(self, k, v)
        except Exception:
            # report which entry failed, then re-raise with the original traceback
            print(f'k={k}, v={v}')
            raise
# Switch TensorFloat32 usage on or off for cuDNN and CUDA matmul, and print the resulting settings.
#   tf32: truthy to allow TF32; coerced with bool() before being assigned to the backend flags.
# NOTE(review): takes no `self` although the neighbouring definitions are methods — presumably a
# @staticmethod whose decorator is not visible here; confirm.
# NOTE(review): how the statements below nest under the two `if`s cannot be determined from this
# view; confirm which of the prints are meant to run only when CUDA is available.
def set_tf32(tf32: bool):
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
# guard: only touch the float32 matmul precision API when this torch build exposes it
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
# Collect the current progress / metric fields into a dict and open self.log_txt_path for appending.
# When self.cur_ep contains '1/', the log file is first (re)created with a JSON header describing the run.
# NOTE(review): several lines of this method appear to be missing in this view (see notes below).
def dump_log(self):
if not dist.is_local_master():
# NOTE(review): no body is visible for the guard above — presumably an early `return` on
# ranks that are not the local master; confirm.
if '1/' in self.cur_ep: # first time to dump log
# mode 'w' truncates any existing log file before the header is written
with open(self.log_txt_path, 'w') as fp:
json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
log_dict = {}
for k, v in {
'it': self.cur_it, 'ep': self.cur_ep,
'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
'remain_time': self.remain_time, 'finish_time': self.finish_time,
# NOTE(review): the end of this dict literal is not visible — presumably `}.items():`.
# values exposing .item() (e.g. 0-d tensors or numpy scalars) are unwrapped via that call
if hasattr(v, 'item'): v = v.item()
log_dict[k] = v
with open(self.log_txt_path, 'a') as fp:
# NOTE(review): the statement that writes log_dict to fp is not visible in this view.
def __str__(self):
    """Render the declared class variables as a brace-delimited, one-per-line listing.

    ``'device'`` and ``'dbg_ks_fp'`` are left out; names are left-aligned in a
    20-character column.
    """
    hidden = {'device', 'dbg_ks_fp'}  # these are not serializable
    rows = [
        f' {name:20s}: {getattr(self, name)}'
        for name in self.class_variables
        if name not in hidden
    ]
    body = '\n'.join(rows)
    return f'{{\n{body}\n}}\n'
# Parse the command line into an Args instance, initialise torch distributed, seed the run and
# fill in the derived fields (patch numbers / resolutions, batch sizes, learning rate, schedule,
# output paths). Returns the populated Args.
# Raises ValueError when args.data_path still holds the '/path/to/imagenet' placeholder.
# NOTE(review): which statements belong to which branch cannot be determined from this view
# (e.g. how many of the assignments after `if args.local_debug:` are conditional); confirm
# against the full file.
def init_dist_and_get_args():
# strip the launcher-injected --local-rank / --local_rank option before parsing
# NOTE(review): no `break` is visible after the `del`; continuing an index loop over
# range(len(sys.argv)) after removing an element would index past the shortened list — confirm.
for i in range(len(sys.argv)):
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
del sys.argv[i]
# known_only=True: unrecognised options end up in args.extra_args (warned about below)
args = Args(explicit_bool=True).parse_args(known_only=True)
if args.local_debug:
args.pn = '1_2_3'
args.seed = 1
args.aln = 1e-2
args.alng = 1e-5
args.saln = False
args.afuse = False
args.pg = 0.8
args.pg0 = 1
if args.data_path == '/path/to/imagenet':
raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
# warn args.extra_args
if len(args.extra_args) > 0:
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
# init torch distributed
from utils import misc
os.makedirs(args.local_out_dir_path, exist_ok=True)
misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
# set env
# cudnn benchmark is requested only when progressive training is off (pg == 0)
args.seed_everything(benchmark=args.pg == 0)
# update args: data loading
args.device = dist.get_device()
# '256' / '512' / '1024' are shortcuts for preset patch-number lists
if args.pn == '256':
args.pn = '1_2_3_4_5_6_8_10_13_16'
elif args.pn == '512':
args.pn = '1_2_3_4_6_9_13_18_24_32'
elif args.pn == '1024':
args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
# both '-' and '_' are accepted as separators in pn
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
args.data_load_reso = max(args.resos)
# update args: bs and lr
# NOTE(review): the comment on Args.batch_size mentions rounding to a multiple of 8
# (`/ 8) * 8`), which is not applied here — confirm which one is intended.
bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
args.batch_size = bs_per_gpu
# args.bs is overwritten with the global batch size recomputed from the rounded per-GPU value
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
# clamp workers to the range [0, batch_size]; negative values become 0
args.workers = min(max(0, args.workers), args.batch_size)
# NOTE(review): the comment on Args.tlr gives lr = base lr * (bs / 256); the expression below
# additionally multiplies by args.ac — confirm.
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
# a final wd of 0 (falsy) falls back to the initial wd
args.twde = args.twde or args.twd
if args.wp == 0:
args.wp = args.ep * 1/50
# update args: progressive training
if args.pgwp == 0:
args.pgwp = args.ep * 1/300
if args.pg > 0:
args.sche = f'lin{args.pg:g}'
# update args: paths
args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
# matches every character that is not a word character or one of - + , .
_reg_valid_name = re.compile(r'[^\w\-+,.]')
# NOTE(review): the arguments and closing parenthesis of this `sub(` call are not visible in
# this view — presumably it sanitises a generated TensorBoard directory name; confirm.
tb_name = _reg_valid_name.sub(
args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
return args