import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager

import psutil
import torch
from packaging import version

from .config import config_to_primitive
from .core import debug, find, info, warn
from .typing import *


def parse_version(ver: str):
    return version.parse(ver)
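

# Illustrative usage (assumption, not part of the original module): this helper
# is typically used to gate version-dependent behavior, e.g.
#
#     if parse_version(torch.__version__) >= parse_version("2.0.0"):
#         ...  # take a torch 2.x code path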


def get_rank():
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    return torch.device(f"cuda:{get_rank()}")


def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a (possibly filtered and remapped) state dict from a checkpoint.

    Returns the state dict to load along with the checkpoint's epoch and
    global step. `module_name` keeps only keys under the given submodule
    (with the prefix stripped), `ignore_modules` drops keys under the given
    submodules, and `mapping` copies keys from one prefix to another.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        state_dict_to_load = {}
        # keep keys that are not shadowed by any mapping destination
        for k, v in state_dict.items():
            if not any(k.startswith(m["to"]) for m in mapping):
                state_dict_to_load[k] = v
        # copy mapped keys over to their destination prefixes
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    k_dest = k.replace(m["from"], m["to"])
                    info(f"Mapping {k} => {k_dest}")
                    state_dict_to_load[k_dest] = v.clone()
        state_dict = state_dict_to_load

    state_dict_to_load = state_dict

    if ignore_modules is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            ignore = any(
                k.startswith(ignore_module + ".") for ignore_module in ignore_modules
            )
            if ignore:
                continue
            state_dict_to_load[k] = v

    if module_name is not None:
        state_dict_to_load = {}
        for k, v in state_dict.items():
            m = re.match(rf"^{module_name}\.(.*)$", k)
            if m is None:
                continue
            state_dict_to_load[m.group(1)] = v

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
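

# Illustrative usage (assumption, not part of the original module): load only the
# geometry weights of a hypothetical checkpoint into a submodule, e.g.
#
#     state_dict, epoch, global_step = load_module_weights(
#         "path/to/last.ckpt", module_name="geometry", map_location="cpu"
#     )
#     model.geometry.load_state_dict(state_dict)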


def C(value: Any, epoch: int, global_step: int) -> float:
    """Resolve a possibly scheduled scalar at the given epoch / global step.

    Plain ints and floats are returned unchanged. A list is interpreted as
    [start_step, start_value, end_value, end_step] (a 3-element list implies
    start_step = 0) and interpolated linearly; an integer end_step schedules
    on global_step, a float end_step schedules on epoch.
    """
    if isinstance(value, (int, float)):
        pass  # constant scalar, use as-is
    else:
        value = config_to_primitive(value)
        if not isinstance(value, list):
            raise TypeError(
                f"Scalar specification only supports list, got {type(value)}"
            )
        if len(value) == 3:
            value = [0] + value
        assert len(value) == 4
        start_step, start_value, end_value, end_step = value
        if isinstance(end_step, int):
            # integer end_step: interpolate over global steps
            current_step = global_step
            value = start_value + (end_value - start_value) * max(
                min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
            )
        elif isinstance(end_step, float):
            # float end_step: interpolate over epochs
            current_step = epoch
            value = start_value + (end_value - start_value) * max(
                min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
            )
    return value
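

# Illustrative usage (assumption, not part of the original module): a config
# entry such as `lambda_sds: [0, 0.1, 1.0, 5000]` could be resolved each
# iteration with
#
#     lambda_sds = C([0, 0.1, 1.0, 5000], epoch, global_step)
#
# which ramps linearly from 0.1 to 1.0 over the first 5000 global steps and
# stays at 1.0 afterwards.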


def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except ImportError:
        # tinycudann is optional; nothing to free if it is not installed
        pass


def finish_with_cleanup(func: Callable):
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
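

# Illustrative usage (not part of the original module): decorate a memory-heavy
# routine so cached CUDA / tinycudann buffers are released after every call, e.g.
#
#     @finish_with_cleanup
#     def run_validation(batch):
#         ...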


def _distributed_available():
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    if not _distributed_available():
        return
    torch.distributed.barrier()


def broadcast(tensor, src=0):
    if not _distributed_available():
        return tensor
    torch.distributed.broadcast(tensor, src=src)
    return tensor
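

# Illustrative usage (not part of the original module): have rank 0 pick a value
# and share it with all ranks; the tensor is assumed to already live on the
# correct device, e.g.
#
#     seed = torch.randint(0, 2**31, (1,), device=get_device())
#     seed = broadcast(seed, src=0)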


def enable_gradient(model, enabled: bool = True) -> None:
    for param in model.parameters():
        param.requires_grad_(enabled)
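

# Illustrative usage (not part of the original module): freeze a pretrained
# model so it is not updated during optimization, e.g.
#
#     enable_gradient(guidance_model, enabled=False)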


class TimeRecorder:
    """Singleton helper that records wall-clock timings of named code regions,
    synchronizing CUDA before each timestamp."""

    _instance = None

    def __init__(self):
        self.items = {}
        self.accumulations = defaultdict(list)
        self.time_scale = 1000.0
        self.time_unit = "ms"
        self.enabled = False

    def __new__(cls):
        # keep a single shared instance; note that __init__ still runs on every
        # TimeRecorder() call and resets its state
        if cls._instance is None:
            cls._instance = super(TimeRecorder, cls).__new__(cls)
        return cls._instance

    def enable(self, enabled: bool) -> None:
        self.enabled = enabled

    def start(self, name: str) -> None:
        if not self.enabled:
            return
        torch.cuda.synchronize()
        self.items[name] = time.time()

    def end(self, name: str, accumulate: bool = False) -> float:
        if not self.enabled or name not in self.items:
            return
        torch.cuda.synchronize()
        start_time = self.items.pop(name)
        delta = time.time() - start_time
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        info(f"{name}: {t:.2f}{self.time_unit}")
        return t

    def get_accumulation(self, name: str, average: bool = False) -> float:
        if not self.enabled or name not in self.accumulations:
            return
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
        return t


time_recorder = TimeRecorder()


@contextmanager
def time_recorder_enabled():
    # temporarily turn the shared recorder on, restoring its previous state on exit
    enabled = time_recorder.enabled
    time_recorder.enable(enabled=True)
    try:
        yield
    finally:
        time_recorder.enable(enabled=enabled)
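

# Illustrative usage (not part of the original module): time a code region only
# while recording is enabled, e.g.
#
#     with time_recorder_enabled():
#         time_recorder.start("render")
#         ...  # do the work to be timed
#         time_recorder.end("render")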


def show_vram_usage(name):
    # report GPU memory in use on the current device and this process's host RSS
    available, total = torch.cuda.mem_get_info()
    used = total - available
    rss = psutil.Process(os.getpid()).memory_info().rss
    print(f"{name}: GPU {used / 1024**2:.1f}MB, RSS {rss / 1024**2:.1f}MB")