import gc
from typing import Any, Dict, Optional, Union

import torch
from accelerate.logging import get_logger

from finetune.constants import LOG_LEVEL, LOG_NAME


logger = get_logger(LOG_NAME, LOG_LEVEL)


def get_memory_statistics(precision: int = 3) -> Dict[str, Any]:
    """Return current and peak memory statistics for the active device, in gigabytes."""
    memory_allocated = None
    memory_reserved = None
    max_memory_allocated = None
    max_memory_reserved = None

    if torch.cuda.is_available():
        device = torch.cuda.current_device()
        memory_allocated = torch.cuda.memory_allocated(device)
        memory_reserved = torch.cuda.memory_reserved(device)
        max_memory_allocated = torch.cuda.max_memory_allocated(device)
        max_memory_reserved = torch.cuda.max_memory_reserved(device)
    elif torch.backends.mps.is_available():
        # MPS only reports the currently allocated bytes; the remaining statistics stay None.
        memory_allocated = torch.mps.current_allocated_memory()
    else:
        logger.warning("No CUDA, MPS, or ROCm device found. Memory statistics are not available.")

    # Round in one place and preserve None for statistics the backend does not report;
    # calling round() on None would raise a TypeError.
    statistics = {
        "memory_allocated": memory_allocated,
        "memory_reserved": memory_reserved,
        "max_memory_allocated": max_memory_allocated,
        "max_memory_reserved": max_memory_reserved,
    }
    return {
        key: round(bytes_to_gigabytes(value), ndigits=precision) if value is not None else None
        for key, value in statistics.items()
    }


def bytes_to_gigabytes(x: Optional[int]) -> Optional[float]:
    """Convert a byte count to gigabytes, passing None through unchanged."""
    if x is None:
        return None
    return x / 1024**3


def free_memory() -> None:
    """Release cached device memory held by PyTorch's caching allocator."""
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        gc.collect()
        torch.mps.empty_cache()


def unload_model(model: torch.nn.Module) -> None:
    """Move a model to CPU so its accelerator memory can be reclaimed."""
    model.to("cpu")


def make_contiguous(x: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Return a contiguous copy of a tensor, recursing into dict values; other types pass through."""
    if isinstance(x, torch.Tensor):
        return x.contiguous()
    elif isinstance(x, dict):
        return {k: make_contiguous(v) for k, v in x.items()}
    else:
        return x
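

# ---------------------------------------------------------------------------
# Minimal usage sketch (a demo, not part of the training pipeline): running
# this module directly exercises the helpers above on a throwaway
# `torch.nn.Linear`; any `nn.Module` would do, and the shapes are arbitrary.
if __name__ == "__main__":
    demo_model = torch.nn.Linear(128, 128)
    if torch.cuda.is_available():
        demo_model = demo_model.to("cuda")

    print("before unload:", get_memory_statistics())

    # A transposed 2D tensor is a classic non-contiguous tensor.
    batch = make_contiguous({"hidden": torch.randn(4, 8).t()})
    print("batch contiguous:", batch["hidden"].is_contiguous())

    unload_model(demo_model)  # move weights back to host memory
    free_memory()             # drop cached allocator blocks
    print("after unload:", get_memory_statistics())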