import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import torch

try:
    from tap import Tap
except ImportError as e:
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    time.sleep(5)
    raise e

import dist
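

# Args doubles as the experiment config and the run-state record: fields marked
# "[automatically set; don't specify this]" are filled in by init_dist_and_get_args() or by the
# trainer at runtime, and the whole object round-trips through state_dict() / load_state_dict().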
class Args(Tap):
    data_path: str = '/path/to/imagenet'
    exp_name: str = 'text'

    # VAE
    vfast: int = 0      # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    # VAR
    tfast: int = 0      # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    depth: int = 16     # VAR depth
    # VAR initialization
    ini: float = -1     # -1: automated model parameter initialization
    hd: float = 0.02    # head.w *= hd
    aln: float = 0.5    # the multiplier of ada_lin.w's initialization
    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization
    # VAR optimization
    fp16: int = 0       # 1: using fp16, 2: bf16
    tblr: float = 1e-4  # base lr
    tlr: float = None   # lr = base lr * (bs / 256)
    twd: float = 0.05   # initial wd
    twde: float = 0     # final wd, =twde or twd
    tclip: float = 2.   # <=0 for not using grad clip
    ls: float = 0.0     # label smooth
    bs: int = 768       # global batch size
    batch_size: int = 0      # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0  # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1              # gradient accumulation
    ep: int = 250
    wp: float = 0
    wp0: float = 0.005       # initial lr ratio at the beginning of lr warm up
    wpe: float = 0.01        # final lr ratio at the end of training
    sche: str = 'lin0'       # lr schedule
    opt: str = 'adamw'       # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
    afuse: bool = True       # fused adamw

    # other hps
    saln: bool = False       # whether to use shared adaln
    anorm: bool = True       # whether to use L2 normalized attention
    fuse: bool = True        # whether to use fused ops like flash attn, xformers, fused MLP, fused LayerNorm, etc.

    # data
    pn: str = '1_2_3_4_5_6_8_10_13_16'
    patch_size: int = 16
    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
    data_load_reso: int = None  # [automatically set; don't specify this] would be max(patch_nums) * patch_size
    mid_reso: float = 1.125     # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
    hflip: bool = False         # augmentation: horizontal flip
    workers: int = 0            # num workers; 0: auto, -1: don't use multiprocessing in DataLoader

    # progressive training
    pg: float = 0.0         # >0: use progressive training during the first [0%, this] fraction of training
    pg0: int = 4            # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc.
    pgwp: float = 0         # num of warmup epochs at each progressive stage

    # would be automatically set in runtime
    cmd: str = ' '.join(sys.argv[1:])  # [automatically set; don't specify this]
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()  # [automatically set; don't specify this]
    acc_mean: float = None      # [automatically set; don't specify this]
    acc_tail: float = None      # [automatically set; don't specify this]
    L_mean: float = None        # [automatically set; don't specify this]
    L_tail: float = None        # [automatically set; don't specify this]
    vacc_mean: float = None     # [automatically set; don't specify this]
    vacc_tail: float = None     # [automatically set; don't specify this]
    vL_mean: float = None       # [automatically set; don't specify this]
    vL_tail: float = None       # [automatically set; don't specify this]
    grad_norm: float = None     # [automatically set; don't specify this]
    cur_lr: float = None        # [automatically set; don't specify this]
    cur_wd: float = None        # [automatically set; don't specify this]
    cur_it: str = ''            # [automatically set; don't specify this]
    cur_ep: str = ''            # [automatically set; don't specify this]
    remain_time: str = ''       # [automatically set; don't specify this]
    finish_time: str = ''       # [automatically set; don't specify this]

    # environment
    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]
    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]
    log_txt_path: str = '...'           # [automatically set; don't specify this]
    last_ckpt_path: str = '...'         # [automatically set; don't specify this]
    tf32: bool = True    # whether to use TensorFloat32
    device: str = 'cpu'  # [automatically set; don't specify this]
    seed: int = None     # seed
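
    # Reproducibility: when --seed is given, Python, NumPy, and torch (CPU + CUDA) RNGs are all
    # seeded with `seed * world_size + rank`, so each rank gets a distinct but deterministic
    # stream and cuDNN runs in deterministic mode; with no seed, cuDNN stays non-deterministic.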
    def seed_everything(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True
            seed = self.seed * dist.get_world_size() + dist.get_rank()
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

    same_seed_for_all_ranks: int = 0  # this is only for distributed sampler
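
    # The per-rank generator below mirrors the seeding scheme above (seed * world_size + rank) so
    # that DataLoader-side random augmentation differs across ranks yet stays reproducible;
    # it returns None when no seed is set.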
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g

    local_debug: bool = 'KEVIN_LOCAL' in os.environ
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ
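
    # compile_model maps the --vfast / --tfast flags onto torch.compile modes
    # (1: 'reduce-overhead', 2: 'max-autotune', 3: 'default') and is a no-op when the flag is 0,
    # in local-debug mode, or on torch versions without torch.compile.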
    def compile_model(self, m, fast):
        if fast == 0 or self.local_debug:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d
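
    # load_state_dict also accepts a plain string (old checkpoint format): lines containing bound
    # methods or torch.device reprs are dropped and the remainder is eval()'d back into a dict.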
    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
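
    # TF32 trades a little float32 matmul/conv precision for speed on GPUs that support it; when
    # enabled, float32 matmul precision is lowered to 'high', otherwise it is kept at 'highest'.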
    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
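
    # Only the local master rank writes logs: the first dump (epoch string '1/...') writes a JSON
    # header with run metadata, and every call appends one dict-formatted line to log.txt.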
    def dump_log(self):
        if not dist.is_local_master():
            return
        if '1/' in self.cur_ep:  # first time to dump log
            with open(self.log_txt_path, 'w') as fp:
                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
                fp.write('\n')
        log_dict = {}
        for k, v in {
            'it': self.cur_it, 'ep': self.cur_ep,
            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
            'remain_time': self.remain_time, 'finish_time': self.finish_time,
        }.items():
            if hasattr(v, 'item'):
                v = v.item()
            log_dict[k] = v
        with open(self.log_txt_path, 'a') as fp:
            fp.write(f'{log_dict}\n')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'
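

# Builds Args from the command line, initializes torch.distributed via utils.misc, then derives
# the runtime fields: device, patch_nums/resos, per-GPU batch size, scaled lr, warmup, and paths.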
def init_dist_and_get_args():
    # strip the --local-rank/--local_rank flag injected by distributed launchers before Tap parses the command line
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if args.local_debug:
        args.pn = '1_2_3'
        args.seed = 1
        args.aln = 1e-2
        args.alng = 1e-5
        args.saln = False
        args.afuse = False
        args.pg = 0.8
        args.pg0 = 1
    else:
        if args.data_path == '/path/to/imagenet':
            raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')

    # warn args.extra_args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    # init torch distributed
    from utils import misc
    os.makedirs(args.local_out_dir_path, exist_ok=True)
    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)

    # set env
    args.set_tf32(args.tf32)
    args.seed_everything(benchmark=args.pg == 0)

    # update args: data loading
    args.device = dist.get_device()
    if args.pn == '256':
        args.pn = '1_2_3_4_5_6_8_10_13_16'
    elif args.pn == '512':
        args.pn = '1_2_3_4_6_9_13_18_24_32'
    elif args.pn == '1024':
        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
    args.data_load_reso = max(args.resos)
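
    # With patch_size=16, the presets above end at 16/32/64 token maps per side, i.e. final
    # resolutions of 256/512/1024 px; data_load_reso is that largest resolution.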

    # update args: bs and lr
    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(max(0, args.workers), args.batch_size)
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.twde = args.twde or args.twd
    if args.wp == 0:
        args.wp = args.ep * 1/50
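
    # Linear lr scaling: glb_batch_size is the per-step batch (roughly bs / ac), so
    # tlr = ac * tblr * glb_batch_size / 256 scales the base lr by the effective batch size
    # relative to a reference of 256; warmup defaults to ep/50 epochs when --wp is left at 0.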

    # update args: progressive training
    if args.pgwp == 0:
        args.pgwp = args.ep * 1/300
    if args.pg > 0:
        args.sche = f'lin{args.pg:g}'

    # update args: paths
    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
    _reg_valid_name = re.compile(r'[^\w\-+,.]')
    tb_name = _reg_valid_name.sub(
        '_',
        f'tb-VARd{args.depth}'
        f'__pn{args.pn}'
        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
    )
    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)

    return args
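

# Minimal usage sketch (assumptions: the module name `arg_util` and its placement under the repo's
# `utils` package are inferred, not stated in this file; training is launched with a distributed launcher):
#
#     from utils import arg_util
#     args = arg_util.init_dist_and_get_args()
#     print(args)      # __str__ lists every resolved field
#     args.dump_log()  # appends a metrics line to local_output/log.txt (local master only)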