# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import json
from typing import Any

import numpy as np
import torch.nn as nn

from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]

class Scheduler:
    # global training progress in [0, 1], updated by RunConfig as training advances
    PROGRESS = 0

class RunConfig:
    n_epochs: int
    init_lr: float
    warmup_epochs: int
    warmup_lr: float
    lr_schedule_name: str
    lr_schedule_param: dict
    optimizer_name: str
    optimizer_params: dict
    weight_decay: float
    no_wd_keys: list
    grad_clip: float  # allow None to turn off grad clipping
    reset_bn: bool
    reset_bn_size: int
    reset_bn_batch_size: int
    eval_image_size: list  # allow None to use image_size in data_provider

    @property
    def none_allowed(self) -> list[str]:
        # config keys that may be set to None
        return ["grad_clip", "eval_image_size"]

    def __init__(self, **kwargs):  # arguments must be passed as kwargs
        for k, val in kwargs.items():
            setattr(self, k, val)

        # check that all required configs are present and correctly typed
        annotations = {}
        for clas in type(self).mro():
            if hasattr(clas, "__annotations__"):
                annotations.update(clas.__annotations__)
        for k, k_type in annotations.items():
            assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
            attr = getattr(self, k)
            if k in self.none_allowed:
                k_type = (k_type, type(None))
            assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."

        self.global_step = 0
        self.batch_per_epoch = 1

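    # Note: the checks above fail fast. For example (illustrative, not from the
    # original file), RunConfig(n_epochs="300", ...) raises
    # AssertionError: Key n_epochs must be type <class 'int'>, provided=300.
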
    def build_optimizer(self, network: nn.Module) -> tuple[Any, Any]:
        r"""Requires `batch_per_epoch` to be set before building the optimizer & lr scheduler."""
        # group parameters by their (weight_decay, lr) pair; the json-encoded
        # pair doubles as a hashable dict key
        param_dict = {}
        for name, param in network.named_parameters():
            if param.requires_grad:
                opt_config = [self.weight_decay, self.init_lr]
                if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
                    # disable weight decay for parameters whose name matches any no_wd key
                    if np.any([key in name for key in self.no_wd_keys]):
                        opt_config[0] = 0
                opt_key = json.dumps(opt_config)
                param_dict[opt_key] = param_dict.get(opt_key, []) + [param]

        net_params = []
        for opt_key, param_list in param_dict.items():
            wd, lr = json.loads(opt_key)
            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

        optimizer = build_optimizer(
            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
        )
        # build lr scheduler (steps are counted per batch, not per epoch)
        if self.lr_schedule_name == "cosine":
            decay_steps = []
            for epoch in self.lr_schedule_param.get("step", []):
                decay_steps.append(epoch * self.batch_per_epoch)
            decay_steps.append(self.n_epochs * self.batch_per_epoch)
            decay_steps.sort()
            lr_scheduler = CosineLRwithWarmup(
                optimizer,
                self.warmup_epochs * self.batch_per_epoch,
                self.warmup_lr,
                decay_steps,
            )
        else:
            raise NotImplementedError(f"lr_schedule_name={self.lr_schedule_name} is not supported")
        return optimizer, lr_scheduler

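    # Usage sketch (`model` and `train_loader` are hypothetical, not part of this file):
    #   run_config.batch_per_epoch = len(train_loader)  # must precede build_optimizer()
    #   optimizer, lr_scheduler = run_config.build_optimizer(model)
    #   ...
    #   lr_scheduler.step()  # once per training batch, since decay_steps are in batches
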
    def update_global_step(self, epoch, batch_id=0) -> None:
        self.global_step = epoch * self.batch_per_epoch + batch_id
        Scheduler.PROGRESS = self.progress

    @property
    def progress(self) -> float:
        # fraction of post-warmup training completed, in [0, 1]
        warmup_steps = self.warmup_epochs * self.batch_per_epoch
        steps = max(0, self.global_step - warmup_steps)
        return steps / (self.n_epochs * self.batch_per_epoch)

    def step(self) -> None:
        self.global_step += 1
        Scheduler.PROGRESS = self.progress

    def get_remaining_epoch(self, epoch, post=True) -> int:
        # remaining epochs including warmup; post=True counts the current epoch as done
        return self.n_epochs + self.warmup_epochs - epoch - int(post)

    def epoch_format(self, epoch: int) -> str:
        epoch_format = f"%.{len(str(self.n_epochs))}d"
        epoch_format = f"[{epoch_format}/{epoch_format}]"
        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
        return epoch_format
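
if __name__ == "__main__":
    # Minimal usage sketch; every value below is illustrative, not taken from the
    # original training recipes.
    run_config = RunConfig(
        n_epochs=300,
        init_lr=1e-3,
        warmup_epochs=5,
        warmup_lr=0.0,
        lr_schedule_name="cosine",
        lr_schedule_param={},
        optimizer_name="adamw",  # assumed to be supported by build_optimizer
        optimizer_params={},
        weight_decay=0.05,
        no_wd_keys=["bias", "norm"],
        grad_clip=None,  # None allowed: disables gradient clipping
        reset_bn=False,
        reset_bn_size=16,
        reset_bn_batch_size=64,
        eval_image_size=None,  # None allowed: use data_provider's image_size
    )
    run_config.batch_per_epoch = 1000  # normally len(train_loader)
    run_config.update_global_step(epoch=10)
    print(f"progress: {run_config.progress:.3f}")  # 5000 / 300000 -> 0.017
    print(run_config.epoch_format(10))  # [006/300]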