|  | """Benchmarking and measurement utilities""" | 
					
						
						|  | import functools | 
					
						
						|  |  | 
					
						
						|  | import pynvml | 
					
						
						|  | import torch | 
					
						
						|  | from pynvml.nvml import NVMLError | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_cuda_device(default_value): | 
					
						
						|  | """ | 
					
						
						|  | wraps a function and returns the default value instead of running the | 
					
						
						|  | wrapped function if cuda isn't available or the device is auto | 
					
						
						|  | :param default_value: | 
					
						
						|  | :return: | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | def deco(func): | 
					
						
						|  | @functools.wraps(func) | 
					
						
						|  | def wrapper(*args, **kwargs): | 
					
						
						|  | device = kwargs.get("device", args[0] if args else None) | 
					
						
						|  |  | 
					
						
						|  | if ( | 
					
						
						|  | device is None | 
					
						
						|  | or not torch.cuda.is_available() | 
					
						
						|  | or device == "auto" | 
					
						
						|  | or torch.device(device).type == "cpu" | 
					
						
						|  | or torch.device(device).type == "meta" | 
					
						
						|  | ): | 
					
						
						|  | return default_value | 
					
						
						|  | return func(*args, **kwargs) | 
					
						
						|  |  | 
					
						
						|  | return wrapper | 
					
						
						|  |  | 
					
						
						|  | return deco | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @check_cuda_device(0.0) | 
					
						
						|  | def gpu_memory_usage(device=0): | 
					
						
						|  | return torch.cuda.memory_allocated(device) / 1024.0**3 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @check_cuda_device((0.0, 0.0, 0.0)) | 
					
						
						|  | def gpu_memory_usage_all(device=0): | 
					
						
						|  | usage = torch.cuda.memory_allocated(device) / 1024.0**3 | 
					
						
						|  | reserved = torch.cuda.memory_reserved(device) / 1024.0**3 | 
					
						
						|  | smi = gpu_memory_usage_smi(device) | 
					
						
						|  | return usage, reserved - usage, max(0, smi - reserved) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def mps_memory_usage_all(): | 
					
						
						|  | usage = torch.mps.current_allocated_memory() / 1024.0**3 | 
					
						
						|  | reserved = torch.mps.driver_allocated_memory() / 1024.0**3 | 
					
						
						|  | return usage, reserved - usage, 0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | @check_cuda_device(0.0) | 
					
						
						|  | def gpu_memory_usage_smi(device=0): | 
					
						
						|  | if isinstance(device, torch.device): | 
					
						
						|  | device = device.index | 
					
						
						|  | if isinstance(device, str) and device.startswith("cuda:"): | 
					
						
						|  | device = int(device[5:]) | 
					
						
						|  | try: | 
					
						
						|  | pynvml.nvmlInit() | 
					
						
						|  | handle = pynvml.nvmlDeviceGetHandleByIndex(device) | 
					
						
						|  | info = pynvml.nvmlDeviceGetMemoryInfo(handle) | 
					
						
						|  | return info.used / 1024.0**3 | 
					
						
						|  | except NVMLError: | 
					
						
						|  | return 0.0 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def log_gpu_memory_usage(log, msg, device): | 
					
						
						|  | if torch.backends.mps.is_available(): | 
					
						
						|  | usage, cache, misc = mps_memory_usage_all() | 
					
						
						|  | else: | 
					
						
						|  | usage, cache, misc = gpu_memory_usage_all(device) | 
					
						
						|  | extras = [] | 
					
						
						|  | if cache > 0: | 
					
						
						|  | extras.append(f"+{cache:.03f}GB cache") | 
					
						
						|  | if misc > 0: | 
					
						
						|  | extras.append(f"+{misc:.03f}GB misc") | 
					
						
						|  | log.info( | 
					
						
						|  | f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2 | 
					
						
						|  | ) | 
					
						
						|  | return usage, cache, misc | 
					
						
						|  |  |