import gc
import inspect
from typing import Optional, Tuple, Union

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.torch_utils import is_compiled_module


logger = get_logger(__name__)


def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    optimizer_name = optimizer_name.lower()

    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            torchao.__version__
        except ImportError:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            )

    if not use_torchao and use_4bit:
        raise ValueError("4-bit optimizers are only supported with torchao.")

    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")

    if use_8bit:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError:
            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }

    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError:
            raise ImportError("To use CAME, please install the came-pytorch library: `pip install came-pytorch`")

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer


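# Illustrative usage of `get_optimizer` (a sketch; the parameter-group structure, the `model` name,
# and the flag values below are assumptions, not requirements of this module). Note that the
# Adam/AdamW branches do not forward `learning_rate` through `init_kwargs`, so any per-group `lr`
# (or the torch default) is what takes effect:
#
#     params_to_optimize = [{"params": [p for p in model.parameters() if p.requires_grad], "lr": 1e-3}]
#     optimizer = get_optimizer(params_to_optimize, optimizer_name="adamw", learning_rate=1e-3, use_8bit=True)

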
def get_gradient_norm(parameters):
    norm = 0
    for param in parameters:
        if param.grad is None:
            continue
        local_norm = param.grad.detach().data.norm(2)
        norm += local_norm.item() ** 2
    norm = norm**0.5
    return norm


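# Sketch of how `get_gradient_norm` is typically called in a training loop (assumed context; the
# `accelerator`, `transformer`, `loss`, and `max_grad_norm` names are placeholders, not defined here):
#
#     accelerator.backward(loss)
#     grad_norm = get_gradient_norm(transformer.parameters())  # unclipped norm, e.g. for logging
#     accelerator.clip_grad_norm_(transformer.parameters(), max_grad_norm)

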
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    tw = tgt_width
    th = tgt_height
    h, w = src
    r = h / w
    if r > (th / tw):
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    if patch_size_t is None:
        # Without temporal patching, derive the spatial crop region for the RoPE grid from the base resolution.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # With temporal patching, round the frame count up to a multiple of `patch_size_t`
        # and build the embedding with a sliced grid bounded by `max_size`.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin


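# Illustrative call (a sketch; the 480x720 resolution, 49-frame clip length, and the `accelerator`
# name are assumptions chosen to match the defaults above, not values mandated by this module):
#
#     freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
#         height=480, width=720, num_frames=49, attention_head_dim=64, device=accelerator.device
#     )
#     # The (cos, sin) pair is typically consumed by the transformer's attention layers as rotary
#     # embeddings; how it is passed (e.g. an `image_rotary_emb` kwarg) depends on the model.

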
def reset_memory(device: Union[str, torch.device]) -> None:
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device: Union[str, torch.device]) -> None:
    memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
    max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory_allocated=:.3f} GB")
    print(f"{max_memory_allocated=:.3f} GB")
    print(f"{max_memory_reserved=:.3f} GB")


def unwrap_model(accelerator: Accelerator, model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model