Spaces:
Configuration error
Configuration error
File size: 6,792 Bytes
624f97e bf48682 35eafc3 624f97e bf48682 624f97e ba81a8e 7cc3853 bf48682 35eafc3 7cc3853 ba81a8e 35eafc3 bf48682 624f97e 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 624f97e 7cc3853 35eafc3 7cc3853 35eafc3 624f97e 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 35eafc3 7cc3853 bf48682 35eafc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from __future__ import annotations
from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING
from faster_whisper import WhisperModel
from speaches.hf_utils import get_piper_voice_model_file
if TYPE_CHECKING:
from collections.abc import Callable
from piper.voice import PiperVoice
from speaches.config import (
WhisperConfig,
)
# Module-level logger used by the model wrapper and the model managers below.
logger: logging.Logger = logging.getLogger(__name__)
# TODO: enable concurrent model downloads
class SelfDisposingModel[T]:
def __init__(
self, model_id: str, load_fn: Callable[[], T], ttl: int, unload_fn: Callable[[str], None] | None = None
) -> None:
self.model_id = model_id
self.load_fn = load_fn
self.ttl = ttl
self.unload_fn = unload_fn
self.ref_count: int = 0
self.rlock = threading.RLock()
self.expire_timer: threading.Timer | None = None
self.model: T | None = None
def unload(self) -> None:
with self.rlock:
if self.model is None:
raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
if self.ref_count > 0:
raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
if self.expire_timer:
self.expire_timer.cancel()
self.model = None
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
gc.collect()
logger.info(f"Model {self.model_id} unloaded")
if self.unload_fn is not None:
self.unload_fn(self.model_id)
def _load(self) -> None:
with self.rlock:
assert self.model is None
logger.debug(f"Loading model {self.model_id}")
start = time.perf_counter()
self.model = self.load_fn()
logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
def _increment_ref(self) -> None:
with self.rlock:
self.ref_count += 1
if self.expire_timer:
logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
self.expire_timer.cancel()
logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
def _decrement_ref(self) -> None:
with self.rlock:
self.ref_count -= 1
logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
if self.ref_count <= 0:
if self.ttl > 0:
logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.ttl}s")
self.expire_timer = threading.Timer(self.ttl, self.unload)
self.expire_timer.start()
elif self.ttl == 0:
logger.info(f"Model {self.model_id} is idle, unloading immediately")
self.unload()
else:
logger.info(f"Model {self.model_id} is idle, not unloading")
def __enter__(self) -> T:
with self.rlock:
if self.model is None:
self._load()
self._increment_ref()
assert self.model is not None
return self.model
def __exit__(self, *_args) -> None: # noqa: ANN002
self._decrement_ref()
class WhisperModelManager:
    """Registry of lazily loaded faster-whisper models, keyed by model name.

    Each entry is a SelfDisposingModel that loads the model on first use and,
    once it unloads, removes itself from ``loaded_models`` through the
    ``_handle_model_unload`` callback.
    """

    def __init__(self, whisper_config: WhisperConfig) -> None:
        self.whisper_config = whisper_config
        self.loaded_models: OrderedDict[str, SelfDisposingModel[WhisperModel]] = OrderedDict()
        # Guards `loaded_models` only. It must never be held while calling into a
        # model wrapper: unloading calls back into `_handle_model_unload`, which
        # acquires this (non-reentrant) lock itself.
        self._lock = threading.Lock()

    def _load_fn(self, model_id: str) -> WhisperModel:
        """Construct a faster-whisper model using the device/compute settings from ``whisper_config``."""
        return WhisperModel(
            model_id,
            device=self.whisper_config.inference_device,
            device_index=self.whisper_config.device_index,
            compute_type=self.whisper_config.compute_type,
            cpu_threads=self.whisper_config.cpu_threads,
            num_workers=self.whisper_config.num_workers,
        )

    def _handle_model_unload(self, model_name: str) -> None:
        """``unload_fn`` callback: forget ``model_name`` after its wrapper unloaded it."""
        with self._lock:
            self.loaded_models.pop(model_name, None)

    def unload_model(self, model_name: str) -> None:
        """Unload the registered model ``model_name`` now.

        Raises:
            KeyError: If no model with that name is registered.
            ValueError: From the wrapper's ``unload`` if the model is not loaded or is still in use.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
        if model is None:
            raise KeyError(f"Model {model_name} not found")
        # Deliberately outside `_lock`: `unload` invokes `_handle_model_unload`,
        # which takes `_lock`. Holding it here would self-deadlock and would also
        # invert the lock order used by the wrapper's own expiry path.
        model.unload()

    def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]:
        """Return the wrapper for ``model_name``, registering a new one if needed.

        The model itself is not loaded here; that happens when the returned
        wrapper is entered as a context manager.
        """
        logger.debug(f"Loading model {model_name}")
        with self._lock:
            logger.debug("Acquired lock")
            if model_name in self.loaded_models:
                logger.debug(f"{model_name} model already loaded")
                return self.loaded_models[model_name]
            self.loaded_models[model_name] = SelfDisposingModel[WhisperModel](
                model_name,
                load_fn=lambda: self._load_fn(model_name),
                ttl=self.whisper_config.ttl,
                unload_fn=self._handle_model_unload,
            )
            return self.loaded_models[model_name]
class PiperModelManager:
    """Registry of lazily loaded Piper voices, keyed by model name.

    Each entry is a SelfDisposingModel that loads the voice on first use and,
    once it unloads, removes itself from ``loaded_models`` through the
    ``_handle_model_unload`` callback.
    """

    def __init__(self, ttl: int) -> None:
        self.ttl = ttl
        self.loaded_models: OrderedDict[str, SelfDisposingModel[PiperVoice]] = OrderedDict()
        # Guards `loaded_models` only. It must never be held while calling into a
        # model wrapper: unloading calls back into `_handle_model_unload`, which
        # acquires this (non-reentrant) lock itself.
        self._lock = threading.Lock()

    def _load_fn(self, model_id: str) -> PiperVoice:
        """Resolve the voice's model file and load it with Piper."""
        # Imported locally; presumably so `piper` is only needed when TTS is used — TODO confirm.
        from piper.voice import PiperVoice

        model_path = get_piper_voice_model_file(model_id)
        return PiperVoice.load(model_path)

    def _handle_model_unload(self, model_name: str) -> None:
        """``unload_fn`` callback: forget ``model_name`` after its wrapper unloaded it."""
        with self._lock:
            self.loaded_models.pop(model_name, None)

    def unload_model(self, model_name: str) -> None:
        """Unload the registered voice ``model_name`` now.

        Raises:
            KeyError: If no model with that name is registered.
            ValueError: From the wrapper's ``unload`` if the model is not loaded or is still in use.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
        if model is None:
            raise KeyError(f"Model {model_name} not found")
        # Deliberately outside `_lock`: `unload` invokes `_handle_model_unload`,
        # which takes `_lock`. Holding it here would self-deadlock and would also
        # invert the lock order used by the wrapper's own expiry path.
        model.unload()

    def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]:
        """Return the wrapper for ``model_name``, registering a new one if needed.

        The voice itself is not loaded here; that happens when the returned
        wrapper is entered as a context manager.
        """
        # Needed at runtime for the generic subscript below; the module-level
        # import of PiperVoice is under TYPE_CHECKING only.
        from piper.voice import PiperVoice

        with self._lock:
            if model_name in self.loaded_models:
                logger.debug(f"{model_name} model already loaded")
                return self.loaded_models[model_name]
            self.loaded_models[model_name] = SelfDisposingModel[PiperVoice](
                model_name,
                load_fn=lambda: self._load_fn(model_name),
                ttl=self.ttl,
                unload_fn=self._handle_model_unload,
            )
            return self.loaded_models[model_name]
|