Spaces:
Running
Running
File size: 7,255 Bytes
529ed6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path as osp
import platform
import subprocess
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
def none_or_int(value):
    """Convert ``value`` to an int, treating the literal string ``"None"`` as ``None``.

    Args:
        value: A string (or anything ``int()`` accepts).

    Returns:
        ``None`` when ``value == "None"``, otherwise ``int(value)``.

    Raises:
        ValueError: If ``value`` is neither ``"None"`` nor parseable by ``int()``.
    """
    return None if value == "None" else int(value)
def inside_slurm():
    """Return True if the ``SLURM_JOB_ID`` environment variable is set (any value).

    Used as a heuristic for "this process was launched through slurm".
    """
    # TODO(rcadene): return False for interactive mode `--pty bash`
    job_id = os.environ.get("SLURM_JOB_ID")
    return job_id is not None
def auto_select_torch_device() -> torch.device:
    """Select the best available torch device automatically.

    Preference order: CUDA, then Apple Metal (MPS), then CPU. The selected
    backend is logged (INFO for accelerators, WARNING for the CPU fallback).

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        logging.info("Cuda backend detected, using cuda.")
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        logging.info("Metal backend detected, using mps.")
        return torch.device("mps")
    logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
    return torch.device("cpu")
# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Build a ``torch.device`` from ``try_device``, checking availability for known backends.

    Args:
        try_device: Device spec; converted with ``str()`` first, so a
            ``torch.device`` is also accepted. ``"cuda"`` and ``"mps"`` are
            asserted to be available; any other value besides ``"cpu"`` is
            passed straight to ``torch.device``.
        log: If True, emit a warning when falling back to CPU or using a
            custom device string.

    Returns:
        torch.device: The requested device.

    Raises:
        AssertionError: If ``"cuda"``/``"mps"`` is requested but unavailable
            (only when assertions are enabled).
    """
    name = str(try_device)

    if name == "cuda":
        assert torch.cuda.is_available()
        return torch.device("cuda")

    if name == "mps":
        assert torch.backends.mps.is_available()
        return torch.device("mps")

    if name == "cpu":
        cpu_device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
        return cpu_device

    # Anything else (e.g. "cuda:1") is trusted as-is; torch.device validates the syntax.
    custom_device = torch.device(name)
    if log:
        logging.warning(f"Using custom {name} device.")
    return custom_device
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
"""
mps is currently not compatible with float64
"""
if isinstance(device, torch.device):
device = device.type
if device == "mps" and dtype == torch.float64:
return torch.float32
else:
return dtype
def is_torch_device_available(try_device: str) -> bool:
    """Return whether the named backend can be used on this machine.

    Args:
        try_device: ``"cuda"``, ``"mps"`` or ``"cpu"``. The value is converted
            with ``str()`` before the lookup.

    Raises:
        ValueError: For any other device name.
    """
    name = str(try_device)  # Ensure try_device is a string
    availability_checks = {
        "cuda": torch.cuda.is_available,
        "mps": torch.backends.mps.is_available,
        "cpu": lambda: True,
    }
    check = availability_checks.get(name)
    if check is None:
        raise ValueError(f"Unknown device {name}. Supported devices are: cuda, mps or cpu.")
    return check()
def is_amp_available(device: str):
if device in ["cuda", "cpu"]:
return True
elif device == "mps":
return False
else:
raise ValueError(f"Unknown device '{device}.")
def init_logging():
    """Configure the root logger with a single compact console handler.

    Sets the root level to INFO via ``basicConfig``. This only takes effect if
    the root logger has no handlers yet. The function then removes all existing
    root handlers and installs one ``StreamHandler``. Each line has the form
    ``LEVEL YYYY-mm-dd HH:MM:SS <last 15 chars of path:line> message``. A
    formatted traceback is appended when the record carries exception info.
    """

    def custom_format(record):
        # Use the record's creation time rather than the time it is formatted.
        dt = datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        # getMessage() merges %-style args into msg; record.msg alone would
        # print raw placeholders such as "x=%s".
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
        if record.exc_info:
            # Keep tracebacks from logging.exception(...) / exc_info=True.
            message = f"{message}\n{formatter.formatException(record.exc_info)}"
        return message

    logging.basicConfig(level=logging.INFO)

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)

    formatter = logging.Formatter()
    formatter.format = custom_format
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logging.getLogger().addHandler(console_handler)
def format_big_number(num, precision=0):
    """Format ``num`` with a thousands-based suffix (K, M, B, T, Q).

    Args:
        num: The number to format (may be negative).
        precision: Number of decimal places to show.

    Returns:
        str: e.g. ``1500 -> "2K"`` with precision 0, ``"1.5K"`` with precision 1.
        Values at or beyond 1000 Q keep the largest suffix (``1e18 -> "1000Q"``)
        instead of falling off the suffix table.
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0
    # Scale down only while a larger suffix is still available, so the final
    # suffix always absorbs arbitrarily large magnitudes.
    for suffix in suffixes[:-1]:
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor
    return f"{num:.{precision}f}{suffixes[-1]}"
def _relative_path_between(path1: Path, path2: Path) -> Path:
    """Return ``path1`` expressed relative to ``path2``.

    Both paths are made absolute first (without resolving symlinks or ``..``).
    If ``path1`` lies under ``path2`` the plain relative tail is returned.
    Otherwise the result climbs out of ``path2`` with ``..`` up to the common
    ancestor and then descends into ``path1``.
    """
    target = path1.absolute()
    base = path2.absolute()
    try:
        return target.relative_to(base)
    except ValueError:  # most likely because path1 is not a subpath of path2
        shared = len(Path(osp.commonpath([target, base])).parts)
        climb = [".."] * (len(base.parts) - shared)
        descend = list(target.parts[shared:])
        return Path("/".join(climb + descend))
def print_cuda_memory_usage():
    """Print current and peak CUDA memory statistics (in MB) for device 0.

    Use this function to locate and debug memory leaks. A garbage collection
    and ``torch.cuda.empty_cache()`` run before the numbers are read.
    """
    import gc

    gc.collect()
    # Also clear the cache if you want to fully release the memory
    torch.cuda.empty_cache()

    bytes_per_mb = 1024**2
    report = (
        ("Current GPU Memory Allocated: {:.2f} MB", torch.cuda.memory_allocated),
        ("Maximum GPU Memory Allocated: {:.2f} MB", torch.cuda.max_memory_allocated),
        ("Current GPU Memory Reserved: {:.2f} MB", torch.cuda.memory_reserved),
        ("Maximum GPU Memory Reserved: {:.2f} MB", torch.cuda.max_memory_reserved),
    )
    for template, query in report:
        print(template.format(query(0) / bytes_per_mb))
def capture_timestamp_utc():
    """Return the current time as a timezone-aware ``datetime`` in UTC."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc
def say(text, blocking=False):
    """Speak ``text`` aloud using the platform's text-to-speech command.

    Uses ``say`` on macOS, ``spd-say`` on Linux and the PowerShell
    ``System.Speech`` synthesizer on Windows.

    Args:
        text: The sentence to speak.
        blocking: If True, run the command to completion with ``check=True``
            (a non-zero exit raises ``subprocess.CalledProcessError``). On Linux
            ``--wait`` is also passed to ``spd-say``. If False, start the command
            with ``Popen`` and return immediately.

    Raises:
        RuntimeError: On an unsupported operating system.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]
    elif system == "Linux":
        cmd = ["spd-say", text]
        if blocking:
            cmd.append("--wait")
    elif system == "Windows":
        # The text is embedded in a single-quoted PowerShell string literal.
        # There a quote is escaped by doubling it, and PowerShell also treats the
        # typographic single quotes as quote characters. Without this, an
        # apostrophe ("Don't") breaks the command and arbitrary text could inject
        # PowerShell code.
        escaped = str(text)
        for quote in ("'", "\u2018", "\u2019", "\u201a", "\u201b"):
            escaped = escaped.replace(quote, quote * 2)
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{escaped}')",
        ]
    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # subprocess.CREATE_NO_WINDOW exists only on Windows, so it is referenced
        # only when running there.
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
    """Log ``text`` at INFO level and, if ``play_sounds`` is truthy, also speak it.

    Args:
        text: Message to log (and possibly speak).
        play_sounds: When falsy, only logging happens.
        blocking: Forwarded to ``say`` as its ``blocking`` argument.
    """
    logging.info(text)
    if not play_sounds:
        return
    say(text, blocking)
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
    """Return ``image_shape`` in channel-first ``(c, h, w)`` order.

    The channel axis is assumed to be the strictly smallest of the first three
    dimensions. A channel-last shape ``(h, w, c)`` is reordered into a new
    tuple. A shape that is already channel-first is returned as a (shallow)
    copy of the input.

    Raises:
        ValueError: If neither the first nor the third dimension is strictly
            the smallest (e.g. ``(3, 3, 3)`` or ``(64, 1, 32)``).
    """
    shape = copy(image_shape)
    d0, d1, d2 = shape[0], shape[1], shape[2]

    if d2 < d0 and d2 < d1:  # (h, w, c) -> (c, h, w)
        return (d2, d0, d1)

    is_channel_first = d0 < d1 and d0 < d2
    if not is_channel_first:
        raise ValueError(image_shape)
    return shape
def has_method(cls: object, method_name: str) -> bool:
    """Return True if ``cls`` has an attribute ``method_name`` that is callable.

    Works on classes and instances alike; non-callable attributes give False.
    """
    if not hasattr(cls, method_name):
        return False
    return callable(getattr(cls, method_name))
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
    """Return True if ``np.dtype(dtype_str)`` succeeds.

    Only ``TypeError`` (numpy's "data type not understood" error) is treated
    as "invalid". Any other exception propagates.
    """
    try:
        np.dtype(dtype_str)
    except TypeError:
        return False
    return True
|