Kevin Hu
committed
Commit · 11bef16
1 Parent(s): 3327e72
Fix fastembed reloading issue. (#4117)
### What problem does this PR solve?

Fix the fastembed reloading issue: cache the loaded model at class level behind a lock, and reload it only when the requested model name changes.
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
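
The change converts both `DefaultEmbedding` and `FastEmbed` to a lock-guarded, class-level model cache that is invalidated when a different `model_name` is requested. The following is a minimal, self-contained sketch of that double-checked-locking pattern; `CachedEmbedder` and `load_model` are hypothetical names for illustration, not the PR's actual code:

```python
import threading

def load_model(name: str) -> object:
    # Hypothetical stand-in for the expensive model constructor
    # (TextEmbedding / FlagModel in the real code).
    return object()

class CachedEmbedder:
    # Class-level cache shared by every instance, mirroring
    # FastEmbed._model / FastEmbed._model_name in the diff.
    _model = None
    _model_name = ""
    _model_lock = threading.Lock()

    def __init__(self, model_name: str):
        # Cheap check outside the lock: skip locking entirely when the
        # right model is already loaded.
        if CachedEmbedder._model is None or model_name != CachedEmbedder._model_name:
            with CachedEmbedder._model_lock:
                # Re-check under the lock: another thread may have loaded
                # the model while we were waiting to acquire it.
                if CachedEmbedder._model is None or model_name != CachedEmbedder._model_name:
                    CachedEmbedder._model = load_model(model_name)
                    CachedEmbedder._model_name = model_name
        self._model = CachedEmbedder._model
        self._model_name = CachedEmbedder._model_name
```

The unlocked first check keeps the common path cheap; the second check under the lock ensures that two racing threads cannot both load the model.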
rag/llm/embedding_model.py CHANGED (+19 -2)
@@ -47,6 +47,7 @@ class Base(ABC):
 
 class DefaultEmbedding(Base):
     _model = None
+    _model_name = ""
     _model_lock = threading.Lock()
     def __init__(self, key, model_name, **kwargs):
         """
@@ -69,6 +70,7 @@ class DefaultEmbedding(Base):
                         DefaultEmbedding._model = FlagModel(os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
+                        DefaultEmbedding._model_name = model_name
                     except Exception:
                         model_dir = snapshot_download(repo_id="BAAI/bge-large-zh-v1.5",
                                                       local_dir=os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
@@ -77,6 +79,7 @@ class DefaultEmbedding(Base):
                                                             query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                             use_fp16=torch.cuda.is_available())
         self._model = DefaultEmbedding._model
+        self._model_name = DefaultEmbedding._model_name
 
     def encode(self, texts: list):
         batch_size = 16
@@ -250,6 +253,8 @@ class OllamaEmbed(Base):
 
 class FastEmbed(Base):
     _model = None
+    _model_name = ""
+    _model_lock = threading.Lock()
 
     def __init__(
         self,
@@ -260,8 +265,20 @@ class FastEmbed(Base):
         **kwargs,
     ):
         if not settings.LIGHTEN and not FastEmbed._model:
-            from fastembed import TextEmbedding
-            self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+            with FastEmbed._model_lock:
+                from fastembed import TextEmbedding
+                if not FastEmbed._model or model_name != FastEmbed._model_name:
+                    try:
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+                        FastEmbed._model_name = model_name
+                    except Exception:
+                        cache_dir = snapshot_download(repo_id="BAAI/bge-small-en-v1.5",
+                                                      local_dir=os.path.join(get_home_cache_dir(),
+                                                                             re.sub(r"^[a-zA-Z0-9]+/", "", model_name)),
+                                                      local_dir_use_symlinks=False)
+                        FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
+        self._model = FastEmbed._model
+        self._model_name = model_name
 
     def encode(self, texts: list):
         # Using the internal tokenizer to encode the texts and get the total
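
The removed lines show the bug: the old `FastEmbed.__init__` assigned the loaded model only to `self._model`, so the class-level `FastEmbed._model` stayed `None` and every instantiation reloaded the model. The intended behavior after the fix, illustrated against the hypothetical `CachedEmbedder` sketch above rather than ragflow's real classes:

```python
a = CachedEmbedder("BAAI/bge-small-en-v1.5")
b = CachedEmbedder("BAAI/bge-small-en-v1.5")
assert a._model is b._model       # same name: cached model is reused, no reload

c = CachedEmbedder("BAAI/bge-large-zh-v1.5")
assert c._model is not a._model   # name changed: exactly one reload happens
```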