|
from typing import Dict, List, Optional, Union |
|
|
|
import torch |
|
from accelerate import Accelerator |
|
from diffusers.utils.torch_utils import is_compiled_module |
|
|
|
|
|
def unwrap_model(accelerator: Accelerator, model):
    """Return the underlying model with Accelerate and ``torch.compile`` wrappers removed.

    Args:
        accelerator: The ``Accelerator`` whose ``unwrap_model`` is used to
            strip the wrappers it added.
        model: The (possibly wrapped and/or compiled) model.

    Returns:
        The result of ``accelerator.unwrap_model(model)``; if that result is a
        compiled module, its ``_orig_mod`` attribute is returned instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        # A compiled module keeps the original module on `_orig_mod`.
        return unwrapped._orig_mod
    return unwrapped
|
|
|
|
|
def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Move a tensor, or every tensor in a (possibly nested) dict, to a device and/or dtype.

    Args:
        x: A tensor, or a dict whose values are tensors or further dicts.
        device: Target device. ``None`` leaves the device unchanged.
        dtype: Target dtype. ``None`` leaves the dtype unchanged.

    Returns:
        The converted tensor, or a new dict with the same keys and converted
        values. When both ``device`` and ``dtype`` are ``None``, or ``x`` is
        neither a tensor nor a dict, ``x`` itself is returned unchanged.
    """
    # Nothing requested: hand back the very same object (no dict rebuild).
    if device is None and dtype is None:
        return x

    if isinstance(x, torch.Tensor):
        # A single `.to` call handles both conversions and skips whichever is None.
        return x.to(device=device, dtype=dtype)

    if isinstance(x, dict):
        # One recursive pass is enough: each call applies both device and dtype.
        return {key: align_device_and_dtype(value, device, dtype) for key, value in x.items()}

    return x
|
|
|
|
|
def expand_tensor_to_dims(tensor, ndim):
    """Append trailing singleton dimensions until ``tensor`` has at least ``ndim`` dims.

    Args:
        tensor: The tensor to expand.
        ndim: The minimum number of dimensions of the result.

    Returns:
        ``tensor`` with size-1 dimensions appended at the end; returned as-is
        when it already has ``ndim`` or more dimensions.
    """
    missing_dims = ndim - len(tensor.shape)
    # `range` of a non-positive count is empty, so no dims are added in that case.
    for _ in range(missing_dims):
        tensor = tensor.unsqueeze(-1)
    return tensor
|
|
|
|
|
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Only parameters with `requires_grad=True` are cast; frozen parameters keep
    their current dtype. The cast is done in place on each parameter's `.data`.

    Args:
        model: The PyTorch model whose parameters will be cast, or a list/tuple of such models.
        dtype: The data type to which the model parameters will be cast.
    """
    # Accept a tuple as well as a list; anything else is treated as a single module.
    modules = list(model) if isinstance(model, (list, tuple)) else [model]
    for module in modules:
        for param in module.parameters():
            if param.requires_grad:
                # Cast from the detached `.data` so the conversion is not tracked by autograd.
                param.data = param.data.to(dtype)
|
|