ReubenSun's picture
1
2ac1c2d
raw
history blame
6.28 kB
import gc
import os
import re
import time
from collections import defaultdict
from contextlib import contextmanager
import psutil
import torch
from packaging import version
from .config import config_to_primitive
from .core import debug, find, info, warn
from .typing import *
def parse_version(ver: str):
    """Parse ``ver`` with ``packaging.version.parse``.

    The result compares semantically (e.g. "1.10" > "1.9"), unlike plain strings.
    """
    parsed = version.parse(ver)
    return parsed
def get_rank():
    """Return this process's rank, read from the first rank env var that is set.

    Variables are consulted in the order RANK, LOCAL_RANK, SLURM_PROCID,
    JSM_NAMESPACE_RANK. Returns 0 when none of them is present.
    """
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    candidates = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    first_found = next(
        (os.environ[env_name] for env_name in candidates if env_name in os.environ),
        None,
    )
    return 0 if first_found is None else int(first_found)
def get_device():
    """Return the CUDA device this process should use.

    CUDA device indices are local to a machine, so the node-local rank
    (``LOCAL_RANK``) is preferred. Using the global rank would give an
    out-of-range index on every node but the first in multi-node runs.
    When ``LOCAL_RANK`` is not set, this falls back to ``get_rank()``,
    which is the previous behaviour.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    index = int(local_rank) if local_rank is not None else get_rank()
    return torch.device(f"cuda:{index}")
def load_module_weights(
    path, module_name=None, ignore_modules=None, mapping=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a (possibly filtered / remapped) state dict from a checkpoint.

    Args:
        path: Checkpoint file passed to ``torch.load``. The loaded object must
            provide ``"state_dict"``, ``"epoch"`` and ``"global_step"`` entries.
        module_name: If given, keep only keys starting with ``"<module_name>."``
            and strip that prefix. The match is literal, not a regex.
        ignore_modules: If given, an iterable of module names; keys starting
            with ``"<name>."`` for any of them are dropped.
        mapping: Optional iterable of dicts with ``"from"`` and ``"to"`` key
            prefixes. Existing keys starting with any ``"to"`` prefix are
            discarded. Every key starting with a ``"from"`` prefix is then
            copied (``clone()``) under the key with that prefix swapped for
            ``"to"``.
        map_location: Forwarded to ``torch.load``; defaults to ``get_device()``.

    Returns:
        ``(state_dict, epoch, global_step)``.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if mapping is not None:
        # Materialize so a generator is not exhausted after its first use.
        mapping = list(mapping)
        target_prefixes = tuple(m["to"] for m in mapping)
        # Drop keys already under a target prefix; the remapped copies below
        # take their place.
        remapped = {
            k: v for k, v in state_dict.items() if not k.startswith(target_prefixes)
        }
        for k, v in state_dict.items():
            for m in mapping:
                if k.startswith(m["from"]):
                    # Swap only the leading prefix. str.replace would also
                    # rewrite later occurrences of m["from"] inside the key.
                    k_dest = m["to"] + k[len(m["from"]):]
                    info(f"Mapping {k} => {k_dest}")
                    remapped[k_dest] = v.clone()
        state_dict = remapped

    state_dict_to_load = state_dict
    if ignore_modules is not None:
        ignored_prefixes = tuple(f"{name}." for name in ignore_modules)
        state_dict_to_load = {
            k: v for k, v in state_dict.items() if not k.startswith(ignored_prefixes)
        }
    if module_name is not None:
        # Literal prefix test: dots (or other regex metacharacters) in
        # module_name must not act as wildcards.
        prefix = f"{module_name}."
        state_dict_to_load = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }
    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
def C(value: Any, epoch: int, global_step: int) -> float:
    """Resolve a scalar that may be scheduled linearly over training.

    ``value`` is either a plain number, which is returned unchanged, or a
    schedule ``[start_step, start_value, end_value, end_step]``. The
    three-element form ``[start_value, end_value, end_step]`` implies
    ``start_step = 0``. The result is interpolated linearly from
    ``start_value`` to ``end_value`` and clamped to that range.

    An ``int`` ``end_step`` drives the schedule with ``global_step``; a
    ``float`` ``end_step`` drives it with ``epoch``. When ``end_step`` equals
    ``start_step``, the value switches from ``start_value`` to ``end_value``
    once that step is reached.

    Args:
        value: A number, a list/tuple schedule, or another object that
            ``config_to_primitive`` turns into a list.
        epoch: Current epoch (used when ``end_step`` is a float).
        global_step: Current global step (used when ``end_step`` is an int).

    Returns:
        The resolved scalar.

    Raises:
        TypeError: If the schedule is not a list, or ``end_step`` is not an
            int or float.
        ValueError: If the schedule does not have 3 or 4 entries.
    """
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, tuple):
        value = list(value)
    elif not isinstance(value, list):
        # Presumably a config container; convert it to plain Python objects.
        value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError(
            f"Scalar specification only supports list, got {type(value)}"
        )
    if len(value) == 3:
        value = [0] + value
    if len(value) != 4:
        raise ValueError(
            f"Scalar specification must have 3 or 4 entries, got {len(value)}"
        )
    start_step, start_value, end_value, end_step = value

    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        raise TypeError(
            f"Scalar specification end_step must be int or float, got {type(end_step)}"
        )

    if end_step == start_step:
        # Degenerate schedule: avoid dividing by zero, treat it as a step change.
        return end_value if current_step >= end_step else start_value
    progress = max(
        min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
    )
    return start_value + (end_value - start_value) * progress
def cleanup():
    """Best-effort release of cached host and GPU memory.

    Runs the garbage collector and ``torch.cuda.empty_cache()``. If
    ``tinycudann`` can be imported, it also calls
    ``tinycudann.free_temporary_memory()``.
    """
    gc.collect()
    torch.cuda.empty_cache()
    try:
        import tinycudann as tcnn

        tcnn.free_temporary_memory()
    except Exception:
        # tinycudann is optional; failing to import it or free its memory must
        # not abort the caller. Unlike a bare ``except:``, this still lets
        # KeyboardInterrupt / SystemExit propagate.
        pass
def finish_with_cleanup(func: Callable):
    """Decorator: run ``cleanup()`` after every call to ``func``.

    Cleanup runs whether ``func`` returns or raises; the return value or
    exception is passed through unchanged. ``functools.wraps`` keeps the
    wrapped function's name, docstring and other metadata.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        finally:
            cleanup()

    return wrapper
def _distributed_available():
return torch.distributed.is_available() and torch.distributed.is_initialized()
def barrier():
    """Synchronize all processes via ``torch.distributed.barrier``.

    This is a no-op when no process group is initialized.
    """
    if _distributed_available():
        torch.distributed.barrier()
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` when distributed is active.

    Returns the same tensor object. Without an initialized process group it
    is returned untouched.
    """
    if _distributed_available():
        # torch.distributed.broadcast works in place on `tensor`.
        torch.distributed.broadcast(tensor, src=src)
    return tensor
def enable_gradient(model, enabled: bool = True) -> None:
    """Turn gradient tracking on or off for every parameter of ``model``, in place."""
    parameters = model.parameters()
    for weight in parameters:
        weight.requires_grad_(enabled)
class TimeRecorder:
    """Process-wide singleton that times named sections and logs the result.

    Recording is off until ``enable(True)`` is called; while disabled, every
    method is a no-op. When CUDA is available, the device is synchronized
    before the clock is read, so queued GPU work counts toward the interval.
    """

    _instance = None

    def __new__(cls):
        # singleton: every TimeRecorder() call returns the same object
        if cls._instance is None:
            cls._instance = super(TimeRecorder, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every TimeRecorder() call even though __new__ hands
        # back the shared instance. Initialize only once so running timers,
        # accumulations and the enabled flag survive re-instantiation.
        if getattr(self, "_initialized", False):
            return
        self.items = {}  # name -> perf_counter() value when the timer started
        self.accumulations = defaultdict(list)  # name -> elapsed seconds per interval
        self.time_scale = 1000.0  # ms
        self.time_unit = "ms"
        self.enabled = False
        self._initialized = True

    @staticmethod
    def _sync() -> None:
        # synchronize() needs a usable CUDA device; skip it on CPU-only setups.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def enable(self, enabled: bool) -> None:
        """Turn recording on or off."""
        self.enabled = enabled

    def start(self, name: str) -> None:
        """Start (or restart) the timer called ``name``."""
        if not self.enabled:
            return
        self._sync()
        self.items[name] = time.perf_counter()

    def end(self, name: str, accumulate: bool = False) -> float:
        """Stop timer ``name``, log the elapsed time and return it in ``time_unit``.

        If ``accumulate`` is true, the elapsed seconds are also appended to
        the accumulation list for ``name``. Returns None when recording is
        disabled or no timer called ``name`` is running.
        """
        if not self.enabled or name not in self.items:
            return None
        self._sync()
        # perf_counter is monotonic, so intervals are immune to clock changes.
        delta = time.perf_counter() - self.items.pop(name)
        if accumulate:
            self.accumulations[name].append(delta)
        t = delta * self.time_scale
        info(f"{name}: {t:.2f}{self.time_unit}")
        return t

    def get_accumulation(self, name: str, average: bool = False) -> float:
        """Log, return and clear the accumulated time for ``name`` in ``time_unit``.

        With ``average=True`` the mean per interval is reported instead of the
        total. Returns None when recording is disabled or nothing has been
        accumulated for ``name``.
        """
        if not self.enabled or name not in self.accumulations:
            return None
        acc = self.accumulations.pop(name)
        total = sum(acc)
        if average:
            t = total / len(acc) * self.time_scale
        else:
            t = total * self.time_scale
        info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}")
        return t


### global time recorder
time_recorder = TimeRecorder()
@contextmanager
def time_recorder_enabled():
    """Temporarily switch the global ``time_recorder`` on.

    The recorder's previous ``enabled`` flag is restored on exit, including
    when the managed block raises.
    """
    previously_enabled = time_recorder.enabled
    time_recorder.enable(enabled=True)
    try:
        yield
    finally:
        # Put back whatever state was active before entering the block.
        time_recorder.enable(enabled=previously_enabled)
def show_vram_usage(name):
    """Print device-wide used GPU memory and this process's resident host memory.

    GPU usage is ``total - free`` as reported by ``torch.cuda.mem_get_info()``.
    Host usage is the process RSS from ``psutil``. Both are divided by 1024**2
    and labelled "MB".
    """
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    bytes_per_mb = 1024**2
    gpu_used_mb = (total_bytes - free_bytes) / bytes_per_mb
    host_rss_mb = psutil.Process(os.getpid()).memory_info().rss / bytes_per_mb
    print(f"{name}: {gpu_used_mb:.1f}MB, {host_rss_mb:.1f}MB")