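"""Model caches for keeping loaded models resident under GPU/CPU memory budgets.

Two variants are provided: ``ARCSizeAwareModelCache``, a size-aware cache based on
the ARC (Adaptive Replacement Cache) policy, and ``LRUModelCache``, a simpler
least-recently-used cache. Both key models by a hash of the model name and its
configuration and evict cached models when the configured budgets are exceeded.
"""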
import hashlib
import json
import time
import threading
from collections import OrderedDict
import torch
from ..hloc import logger
class ARCSizeAwareModelCache:
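    """Size-aware model cache based on the ARC (Adaptive Replacement Cache) policy.

    Resident models live in ``t1`` (seen once recently) and ``t2`` (seen more than
    once); ``b1``/``b2`` are ghost lists that remember recently evicted keys so a
    re-loaded model can be admitted directly into ``t2``. Eviction is size-aware:
    cached models are released until the per-device memory budget
    (``max_gpu_mem`` / ``max_cpu_mem``) can accommodate the incoming model.
    """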
def __init__(
self,
max_gpu_mem: float = 12e9,
max_cpu_mem: float = 12e9,
device_priority: list = ["cuda", "cpu"],
auto_empty_cache: bool = True,
):
"""
Initialize the model cache.
Args:
max_gpu_mem: Maximum GPU memory allowed in bytes.
max_cpu_mem: Maximum CPU memory allowed in bytes.
device_priority: List of devices to prioritize when evicting models.
auto_empty_cache: Whether to call torch.cuda.empty_cache() when out of memory.
"""
        # ARC bookkeeping: T1/T2 hold resident models (recently vs. frequently
        # used); B1/B2 are ghost lists that remember recently evicted keys.
        self.t1 = OrderedDict()
        self.t2 = OrderedDict()
        self.b1 = OrderedDict()
        self.b2 = OrderedDict()
        self.max_gpu = max_gpu_mem
        self.max_cpu = max_cpu_mem
        self.current_gpu = 0
        self.current_cpu = 0
        # Adaptation target for the size of T1 (the ARC "p" parameter).
        self.p = 0
        self.adaptive_factor = 0.5  # reserved for tuning; currently unused
        self.device_priority = device_priority
        self.lock = threading.Lock()
        self.auto_empty_cache = auto_empty_cache
logger.info("ARCSizeAwareModelCache initialized.")
def _release_model(self, model_entry):
"""
Release a model from memory.
Args:
model_entry: A dictionary containing the model, device and other information.
Notes:
If the device is CUDA and auto_empty_cache is True, torch.cuda.empty_cache() is called after releasing the model.
"""
model = model_entry["model"]
device = model_entry["device"]
del model
if device == "cuda":
torch.cuda.synchronize()
if self.auto_empty_cache:
torch.cuda.empty_cache()
    def generate_key(self, model_key, model_conf: dict) -> str:
        """Build a deterministic cache key from the model name and its configuration."""
        unique_str = f"{model_key}-{json.dumps(model_conf, sort_keys=True)}"
        return hashlib.sha256(unique_str.encode()).hexdigest()
    def _get_device(self, model_size: int) -> str:
        """Pick the first device in priority order with room for the model; default to CPU."""
        for device in self.device_priority:
            if device == "cuda" and torch.cuda.is_available():
                if self.current_gpu + model_size <= self.max_gpu:
                    return "cuda"
            elif device == "cpu":
                if self.current_cpu + model_size <= self.max_cpu:
                    return "cpu"
        return "cpu"
def _calculate_model_size(self, model):
return sum(p.numel() * p.element_size() for p in model.parameters()) + sum(
b.numel() * b.element_size() for b in model.buffers()
)
    def _update_access(self, key: str, size: int, device: str):
        """Adapt the ARC target ``p`` on ghost-list hits and promote T1 hits to T2."""
        if key in self.b1:
            self.p = min(
                self.p + max(1, len(self.b2) // len(self.b1)),
                len(self.t1) + len(self.t2),
            )
            self.b1.pop(key)
            self._replace(False)
        elif key in self.b2:
            self.p = max(self.p - max(1, len(self.b1) // len(self.b2)), 0)
            self.b2.pop(key)
            self._replace(True)
        if key in self.t1:
            # Promote a T1 hit to T2, keeping the cached model and its metadata.
            entry = self.t1.pop(key)
            entry["size"] = size
            entry["device"] = device
            entry["access_count"] = entry.get("access_count", 0) + 1
            entry["last_accessed"] = time.time()
            self.t2[key] = entry
    def _replace(self, in_t2: bool):
        # Demote the LRU entry of T1 (or T2) to its ghost list, releasing the
        # model and its memory budget; ghost entries keep metadata only.
        if len(self.t1) > 0 and (
            (len(self.t1) > self.p) or (in_t2 and len(self.t1) == self.p)
        ):
            k, v = self.t1.popitem(last=False)
            ghost = self.b1
        elif len(self.t2) > 0:
            k, v = self.t2.popitem(last=False)
            ghost = self.b2
        else:
            return
        self._release_model(v)
        self.current_gpu -= v["size"] if v["device"] == "cuda" else 0
        self.current_cpu -= v["size"] if v["device"] == "cpu" else 0
        ghost[k] = {f: v[f] for f in ("size", "device", "access_count", "last_accessed")}
    def _calculate_weight(self, entry) -> float:
        """Eviction priority: entries with fewer accesses per byte are evicted first."""
        return entry["access_count"] / entry["size"]
    def _evict_models(self, required_size: int, target_device: str) -> bool:
        """Evict models on ``target_device``, lightest weight first, until enough memory is free."""
        candidates = []
        for k, v in list(self.t1.items()) + list(self.t2.items()):
            if v["device"] == target_device:
                candidates.append((k, v))
        candidates.sort(key=lambda x: self._calculate_weight(x[1]))
        freed = 0
        for k, v in candidates:
            self._release_model(v)
            freed += v["size"]
            if k in self.t1:
                self.t1.pop(k)
            elif k in self.t2:
                self.t2.pop(k)
            if v["device"] == "cuda":
                self.current_gpu -= v["size"]
            else:
                self.current_cpu -= v["size"]
            if freed >= required_size:
                return True
        if target_device == "cuda":
            # Same-device eviction was not enough; free the remainder across devices.
            return self._cross_device_evict(required_size - freed, "cuda")
        return False
    def _cross_device_evict(self, required_size: int, target_device: str) -> bool:
        """Last-resort eviction across all devices, evicting entries not on ``target_device`` first."""
        all_entries = list(self.t1.items()) + list(self.t2.items())
        all_entries.sort(
            key=lambda x: self._calculate_weight(x[1])
            + (0.5 if x[1]["device"] == target_device else 0)
        )
        freed = 0
        for k, v in all_entries:
            self._release_model(v)
            freed += v["size"]
            if k in self.t1:
                self.t1.pop(k)
            elif k in self.t2:
                self.t2.pop(k)
            if v["device"] == "cuda":
                self.current_gpu -= v["size"]
            else:
                self.current_cpu -= v["size"]
            if freed >= required_size:
                return True
        return False
    def load_model(self, model_key, model_loader_func, model_conf: dict):
        """Return a cached model, loading it and admitting it to the cache on a miss."""
        key = self.generate_key(model_key, model_conf)
        with self.lock:
            if key in self.t1 or key in self.t2:
                # Cache hit: promote the entry to the MRU end of T2.
                entry = self.t1.pop(key, None) or self.t2.pop(key)
                entry["access_count"] += 1
                entry["last_accessed"] = time.time()
                self.t2[key] = entry
                return entry["model"]
raw_model = model_loader_func(model_conf)
model_size = self._calculate_model_size(raw_model)
device = self._get_device(model_size)
if device == "cuda" and self.auto_empty_cache:
torch.cuda.empty_cache()
torch.cuda.synchronize()
while True:
current_mem = self.current_gpu if device == "cuda" else self.current_cpu
max_mem = self.max_gpu if device == "cuda" else self.max_cpu
if current_mem + model_size <= max_mem:
break
if not self._evict_models(model_size, device):
if device == "cuda":
device = "cpu"
else:
raise RuntimeError("Out of memory")
try:
model = raw_model.to(device)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
torch.cuda.empty_cache()
model = raw_model.to(device)
new_entry = {
"model": model,
"size": model_size,
"device": device,
"access_count": 1,
"last_accessed": time.time(),
}
            if key in self.b1 or key in self.b2:
                # Ghost hit: the key was evicted recently, so re-admit it
                # directly into T2 and drop the ghost entry.
                self.b1.pop(key, None)
                self.b2.pop(key, None)
                self.t2[key] = new_entry
            else:
                # Genuine miss: new entries enter T1. Memory capacity is already
                # enforced by the eviction loop above, so nothing is demoted here.
                self.t1[key] = new_entry
            if device == "cuda":
                self.current_gpu += model_size
            else:
                self.current_cpu += model_size
            return model
    def clear_device_cache(self, device: str):
        with self.lock:
            for cache in (self.t1, self.t2, self.b1, self.b2):
                for k in list(cache.keys()):
                    if cache[k]["device"] == device:
                        entry = cache.pop(k)
                        # Only resident entries hold a model and count towards memory use.
                        if "model" in entry:
                            self._release_model(entry)
                            self.current_gpu -= entry["size"] if device == "cuda" else 0
                            self.current_cpu -= entry["size"] if device == "cpu" else 0
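# Usage sketch for ARCSizeAwareModelCache (illustrative only; `build_model` and
# its configuration below are hypothetical placeholders for the real loaders):
#
#   cache = ARCSizeAwareModelCache(max_gpu_mem=12e9, max_cpu_mem=12e9)
#   model = cache.load_model("example_model", build_model, {"weights": "default"})
#   model = cache.load_model("example_model", build_model, {"weights": "default"})  # cache hit
#   cache.clear_device_cache("cuda")  # release every GPU-resident model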
class LRUModelCache:
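    """Simple LRU model cache with per-device (GPU/CPU) memory budgets.

    Models are keyed by a hash of the model name and configuration; the least
    recently used entries on a device are evicted when a new model would exceed
    that device's budget.
    """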
def __init__(
self,
max_gpu_mem: float = 8e9,
max_cpu_mem: float = 12e9,
device_priority: list = ["cuda", "cpu"],
):
self.cache = OrderedDict()
self.max_gpu = max_gpu_mem
self.max_cpu = max_cpu_mem
self.current_gpu = 0
self.current_cpu = 0
        # Re-entrant: _handle_oom and clear_device_cache re-acquire it under load_model.
        self.lock = threading.RLock()
self.device_priority = device_priority
    def generate_key(self, model_key, model_conf: dict) -> str:
        """Build a deterministic cache key from the model name and its configuration."""
        unique_str = f"{model_key}-{json.dumps(model_conf, sort_keys=True)}"
        return hashlib.sha256(unique_str.encode()).hexdigest()
    def get_device(self) -> str:
        """Return the first device in priority order that is still under its memory budget."""
for device in self.device_priority:
if device == "cuda" and torch.cuda.is_available():
if self.current_gpu < self.max_gpu:
return device
elif device == "cpu":
if self.current_cpu < self.max_cpu:
return device
return "cpu"
def _calculate_model_size(self, model):
param_size = sum(p.numel() * p.element_size() for p in model.parameters())
buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
return param_size + buffer_size
def load_model(self, model_key, model_loader_func, model_conf: dict):
key = self.generate_key(model_key, model_conf)
with self.lock:
if key in self.cache:
self.cache.move_to_end(key) # update LRU
return self.cache[key]["model"]
device = self.get_device()
if device == "cuda":
torch.cuda.empty_cache()
            try:
                raw_model = model_loader_func(model_conf)
            except Exception as e:
                raise RuntimeError(f"Model loading failed: {e}") from e
try:
model = raw_model.to(device)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
return self._handle_oom(model_key, model_loader_func, model_conf)
raise
model_size = self._calculate_model_size(model)
while (
device == "cuda" and (self.current_gpu + model_size > self.max_gpu)
) or (device == "cpu" and (self.current_cpu + model_size > self.max_cpu)):
if not self._free_space(model_size, device):
raise RuntimeError("Insufficient memory even after cache cleanup")
if device == "cuda":
self.current_gpu += model_size
else:
self.current_cpu += model_size
self.cache[key] = {
"model": model,
"size": model_size,
"device": device,
"timestamp": time.time(),
}
return model
    def _free_space(self, required_size: int, device: str) -> bool:
        # Evict least-recently-used entries on the same device until the new model fits.
        for key in list(self.cache.keys()):
            entry = self.cache[key]
            if entry["device"] != device:
                continue
            if device == "cuda":
                self.current_gpu -= entry["size"]
            else:
                self.current_cpu -= entry["size"]
            del self.cache[key]
            if device == "cuda" and self.current_gpu + required_size <= self.max_gpu:
                return True
            if device == "cpu" and self.current_cpu + required_size <= self.max_cpu:
                return True
        return False
    def _handle_oom(self, model_key, model_loader_func, model_conf: dict):
        """Recover from a CUDA OOM: clear the GPU cache, retry, then fall back to CPU."""
with self.lock:
self.clear_device_cache("cuda")
torch.cuda.empty_cache()
try:
return self.load_model(model_key, model_loader_func, model_conf)
except RuntimeError:
original_priority = self.device_priority
self.device_priority = ["cpu"]
try:
return self.load_model(model_key, model_loader_func, model_conf)
finally:
self.device_priority = original_priority
def clear_device_cache(self, device: str):
with self.lock:
keys_to_remove = [k for k, v in self.cache.items() if v["device"] == device]
for k in keys_to_remove:
self.current_gpu -= self.cache[k]["size"] if device == "cuda" else 0
self.current_cpu -= self.cache[k]["size"] if device == "cpu" else 0
del self.cache[k]
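# ---------------------------------------------------------------------------
# Minimal usage sketch (assumption: any callable mapping a config dict to a
# torch.nn.Module can serve as the loader; the tiny Linear model below is a
# hypothetical stand-in for the real models cached by hloc). Run this module
# with `python -m` from inside its package so the relative `logger` import
# resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def _demo_loader(conf: dict):
        # Hypothetical loader used only for this sketch.
        return torch.nn.Linear(conf["in_features"], conf["out_features"])

    demo_conf = {"in_features": 64, "out_features": 8}
    lru_cache = LRUModelCache(max_gpu_mem=4e9, max_cpu_mem=8e9)
    first = lru_cache.load_model("demo_linear", _demo_loader, demo_conf)
    second = lru_cache.load_model("demo_linear", _demo_loader, demo_conf)
    assert first is second  # the second call is served from the cache
    logger.info("LRUModelCache demo: cached model reused on second call.")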