KevinHuSh committed
Commit ba51460 · 1 parent 67dea7a

Add bce-embedding and fastembed (#383)

### What problem does this PR solve?

Issue link: #326

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- README.md +2 -0
- README_ja.md +2 -0
- README_zh.md +2 -0
- api/apps/chunk_app.py +1 -1
- api/apps/document_app.py +1 -0
- api/apps/llm_app.py +2 -2
- api/db/init_data.py +19 -11
- api/db/services/dialog_service.py +5 -1
- api/db/services/llm_service.py +8 -3
- api/db/services/task_service.py +1 -1
- rag/llm/__init__.py +2 -2
- rag/llm/embedding_model.py +49 -14
- rag/nlp/search.py +1 -1
- rag/svr/task_executor.py +2 -1
- requirements.txt +2 -0
README.md
CHANGED

```diff
@@ -55,6 +55,8 @@

 ## 📌 Latest Features

+- 2024-04-16 Add an embedding model 'bce-embedding-base_v1' from [QAnything](https://github.com/netease-youdao/QAnything).
+- 2024-04-16 Add [FastEmbed](https://github.com/qdrant/fastembed), which is designed for light and fast embedding.
 - 2024-04-11 Support [Xinference](./docs/xinference.md) for local LLM deployment.
 - 2024-04-10 Add a new layout recognition model for analyzing Laws documentation.
 - 2024-04-08 Support [Ollama](./docs/ollama.md) for local LLM deployment.
```
README_ja.md
CHANGED

```diff
@@ -55,6 +55,8 @@

 ## 📌 最新の機能

+- 2024-04-16 [QAnything](https://github.com/netease-youdao/QAnything) から埋め込みモデル「bce-embedding-base_v1」を追加します。
+- 2024-04-16 [FastEmbed](https://github.com/qdrant/fastembed) は、軽量かつ高速な埋め込み用に設計されています。
 - 2024-04-11 ローカル LLM デプロイメント用に [Xinference](./docs/xinference.md) をサポートします。
 - 2024-04-10 メソッド「Laws」に新しいレイアウト認識モデルを追加します。
 - 2024-04-08 [Ollama](./docs/ollama.md) を使用した大規模モデルのローカライズされたデプロイメントをサポートします。
```
README_zh.md
CHANGED

```diff
@@ -55,6 +55,8 @@

 ## 📌 新增功能

+- 2024-04-16 添加嵌入模型 [QAnything的bce-embedding-base_v1](https://github.com/netease-youdao/QAnything)。
+- 2024-04-16 添加 [FastEmbed](https://github.com/qdrant/fastembed)，专为轻型和高速嵌入而设计。
 - 2024-04-11 支持用 [Xinference](./docs/xinference.md) 本地化部署大模型。
 - 2024-04-10 为‘Laws’版面分析增加了底层模型。
 - 2024-04-08 支持用 [Ollama](./docs/ollama.md) 本地化部署大模型。
```
api/apps/chunk_app.py
CHANGED

```diff
@@ -252,7 +252,7 @@ def retrieval_test():
         return get_data_error_result(retmsg="Knowledgebase not found!")

     embd_mdl = TenantLLMService.model_instance(
-        kb.tenant_id, LLMType.EMBEDDING.value)
+        kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, similarity_threshold,
                                   vector_similarity_weight, top, doc_ids)
     for c in ranks["chunks"]:
```
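With this change the retrieval test embeds the query with the knowledge base's own `embd_id` rather than the tenant-wide default, so a KB indexed with bce-embedding-base_v1 is also queried with it. A minimal sketch of the resolution rule (the function name is hypothetical):

```python
def resolve_embedding_model(tenant_default: str, kb_embd_id=None) -> str:
    """Prefer the knowledge base's own embedding model over the tenant default."""
    return kb_embd_id or tenant_default

assert resolve_embedding_model("BAAI/bge-large-zh-v1.5") == "BAAI/bge-large-zh-v1.5"
assert resolve_embedding_model(
    "BAAI/bge-large-zh-v1.5",
    kb_embd_id="maidalun1020/bce-embedding-base_v1") == "maidalun1020/bce-embedding-base_v1"
```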
api/apps/document_app.py
CHANGED

```diff
@@ -15,6 +15,7 @@
 #

 import base64
+import os
 import pathlib
 import re
```
api/apps/llm_app.py
CHANGED

```diff
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
 def factories():
     try:
         fac = LLMFactoriesService.get_all()
-        return get_json_result(data=[f.to_dict() for f in fac])
+        return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["QAnything", "FastEmbed"]])
     except Exception as e:
         return server_error_response(e)

@@ -174,7 +174,7 @@ def list():
     llms = [m.to_dict()
             for m in llms if m.status == StatusEnum.VALID.value]
     for m in llms:
-        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding"
+        m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["QAnything", "FastEmbed"]

     llm_set = set([m["llm_name"] for m in llms])
     for o in objs:
```
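Together the two hunks treat QAnything and FastEmbed as built-in, key-less factories: they are hidden from the credential-configuration list, while their models always count as available. A hedged sketch of the availability predicate:

```python
KEYLESS_FACTORIES = {"QAnything", "FastEmbed"}

def is_available(model: dict, configured_fids: set) -> bool:
    # A model is usable if its factory has stored credentials, or it is the
    # bundled flag-embedding model, or its factory needs no key at all.
    return (model["fid"] in configured_fids
            or model["llm_name"].lower() == "flag-embedding"
            or model["fid"] in KEYLESS_FACTORIES)

print(is_available({"fid": "QAnything",
                    "llm_name": "maidalun1020/bce-embedding-base_v1"}, set()))  # True
```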
api/db/init_data.py
CHANGED

```diff
@@ -18,7 +18,7 @@ import time
 import uuid

 from api.db import LLMType, UserTenantRole
-from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM
+from api.db.db_models import init_database_tables as init_web_db, LLMFactories, LLM, TenantLLM
 from api.db.services import UserService
 from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
 from api.db.services.user_service import TenantService, UserTenantService
@@ -114,12 +114,16 @@ factory_infos = [{
     "logo": "",
     "tags": "TEXT EMBEDDING",
     "status": "1",
-},
-    {
+}, {
     "name": "Xinference",
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
     "status": "1",
+},{
+    "name": "QAnything",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+    "status": "1",
 },
 # {
 #     "name": "文心一言",
@@ -254,12 +258,6 @@ def init_llm_factory():
     "tags": "LLM,CHAT,",
     "max_tokens": 7900,
     "model_type": LLMType.CHAT.value
-}, {
-    "fid": factory_infos[4]["name"],
-    "llm_name": "flag-embedding",
-    "tags": "TEXT EMBEDDING,",
-    "max_tokens": 128 * 1000,
-    "model_type": LLMType.EMBEDDING.value
 }, {
     "fid": factory_infos[4]["name"],
     "llm_name": "moonshot-v1-32k",
@@ -325,6 +323,14 @@ def init_llm_factory():
     "max_tokens": 2147483648,
     "model_type": LLMType.EMBEDDING.value
 },
+# ------------------------ QAnything -----------------------
+{
+    "fid": factory_infos[7]["name"],
+    "llm_name": "maidalun1020/bce-embedding-base_v1",
+    "tags": "TEXT EMBEDDING,",
+    "max_tokens": 512,
+    "model_type": LLMType.EMBEDDING.value
+},
 ]
 for info in factory_infos:
     try:
@@ -337,8 +343,10 @@ def init_llm_factory():
     except Exception as e:
         pass

-    LLMFactoriesService.filter_delete([LLMFactories.name=="Local"])
-    LLMService.filter_delete([LLM.fid=="Local"])
+    LLMFactoriesService.filter_delete([LLMFactories.name == "Local"])
+    LLMService.filter_delete([LLM.fid == "Local"])
+    LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
+    TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])

     """
     drop table llm;
```
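One caveat in this hunk: `factory_infos[7]["name"]` depends on QAnything staying the eighth entry of `factory_infos`, so inserting a factory earlier in the list would silently re-point the model at the wrong factory. A name-based lookup would be sturdier; a sketch of that alternative (not what the PR does):

```python
factory_infos = [
    {"name": "OpenAI"}, {"name": "Xinference"}, {"name": "QAnything"},  # abridged
]

def factory_name(name: str) -> str:
    # Fail loudly if the factory is missing instead of pointing at a wrong index.
    matches = [f["name"] for f in factory_infos if f["name"] == name]
    assert matches, f"factory {name!r} not seeded"
    return matches[0]

bce_entry = {
    "fid": factory_name("QAnything"),
    "llm_name": "maidalun1020/bce-embedding-base_v1",
    "tags": "TEXT EMBEDDING,",
    "max_tokens": 512,
}
```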
api/db/services/dialog_service.py
CHANGED

```diff
@@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
         raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 1024
     else: max_tokens = llm[0].max_tokens
+    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
+    embd_nms = list(set([kb.embd_id for kb in kbs]))
+    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+
     questions = [m["content"] for m in messages if m["role"] == "user"]
-    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

     prompt_config = dialog.prompt_config
```
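The new assertion guards against attaching knowledge bases indexed with different embedding models to one dialog: their chunk vectors live in different vector spaces, often of different dimensionality, so a single query vector cannot be scored against both. A toy demonstration, assuming the common 768-d (bce-embedding-base_v1) versus 1024-d (bge-large-zh-v1.5) sizes:

```python
import numpy as np

query_vec = np.random.rand(768)        # embedded with bce-embedding-base_v1
foreign_chunk = np.random.rand(1024)   # indexed with bge-large-zh-v1.5

try:
    np.dot(query_vec, foreign_chunk)   # dot/cosine scoring is undefined across spaces
except ValueError as e:
    print("incompatible embedding spaces:", e)
```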
api/db/services/llm_service.py
CHANGED

```diff
@@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
             raise LookupError("Tenant not found")

         if llm_type == LLMType.EMBEDDING.value:
-            mdlnm = tenant.embd_id
+            mdlnm = tenant.embd_id if not llm_name else llm_name
         elif llm_type == LLMType.SPEECH2TEXT.value:
             mdlnm = tenant.asr_id
         elif llm_type == LLMType.IMAGE2TEXT.value:
@@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
             assert False, "LLM type error"

         model_config = cls.get_api_key(tenant_id, mdlnm)
+        if model_config: model_config = model_config.to_dict()
         if not model_config:
-            raise LookupError("Model({}) not authorized".format(mdlnm))
-        model_config = model_config.to_dict()
+            if llm_type == LLMType.EMBEDDING.value:
+                llm = LLMService.query(llm_name=llm_name)
+                if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
+                    model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": llm_name, "api_base": ""}
+            if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
+
         if llm_type == LLMType.EMBEDDING.value:
             if model_config["llm_factory"] not in EmbeddingModel:
                 return
```
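`model_instance` now falls back to a synthesized, empty credential record when the tenant never stored an API key but the requested embedder comes from a key-less local factory. A condensed sketch of that control flow, with the database lookups stubbed out:

```python
KEYLESS = {"QAnything", "FastEmbed"}

def resolve_config(stored, llm_name, factory):
    """stored: the tenant's saved config dict, or None if no key was ever set."""
    if stored:
        return stored
    if factory in KEYLESS:
        # Local embedders run without credentials, so fabricate an empty config.
        return {"llm_factory": factory, "api_key": "", "llm_name": llm_name, "api_base": ""}
    raise LookupError("Model({}) not authorized".format(llm_name))

print(resolve_config(None, "maidalun1020/bce-embedding-base_v1", "QAnything"))
```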
api/db/services/task_service.py
CHANGED

```diff
@@ -41,7 +41,7 @@ class TaskService(CommonService):
             Document.size,
             Knowledgebase.tenant_id,
             Knowledgebase.language,
-            Tenant.embd_id,
+            Knowledgebase.embd_id,
             Tenant.img2txt_id,
             Tenant.asr_id,
             cls.model.update_time]
```
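Selecting `Knowledgebase.embd_id` means every task row carries the embedding model of the knowledge base that owns the document; `task_executor` below reads it as `r["embd_id"]`. A sketch of the row shape this query yields (columns abridged, values hypothetical):

```python
task_row = {
    "tenant_id": "t-0001",
    "language": "English",
    "embd_id": "maidalun1020/bce-embedding-base_v1",  # now sourced from the knowledge base
}
# Downstream (see rag/svr/task_executor.py below):
# embd_mdl = LLMBundle(task_row["tenant_id"], LLMType.EMBEDDING,
#                      llm_name=task_row["embd_id"], lang=task_row["language"])
```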
rag/llm/__init__.py
CHANGED

```diff
@@ -24,8 +24,8 @@ EmbeddingModel = {
     "Xinference": XinferenceEmbed,
     "Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
     "ZHIPU-AI": ZhipuEmbed,
-    "
-    "
+    "FastEmbed": FastEmbed,
+    "QAnything": QAnythingEmbed
 }
```
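`EmbeddingModel` is a plain name-to-class registry, so adding a backend is one import plus one entry, and callers dispatch on the factory name stored with the tenant. A sketch of the dispatch with a stub backend standing in for the real classes:

```python
import numpy as np

class StubEmbed:
    def __init__(self, key, model_name):
        self.model_name = model_name
    def encode_queries(self, text):
        return np.zeros(4), len(text.split())

EmbeddingModel = {"FastEmbed": StubEmbed, "QAnything": StubEmbed}  # registry shape

factory = "QAnything"                                              # e.g. llm[0].fid
mdl = EmbeddingModel[factory](key="", model_name="maidalun1020/bce-embedding-base_v1")
vec, tokens = mdl.encode_queries("what is bce embedding?")
print(mdl.model_name, vec.shape, tokens)
```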
rag/llm/embedding_model.py
CHANGED

```diff
@@ -20,7 +20,6 @@ from abc import ABC
 from ollama import Client
 import dashscope
 from openai import OpenAI
-from fastembed import TextEmbedding
 from FlagEmbedding import FlagModel
 import torch
 import numpy as np
@@ -28,16 +27,17 @@ import numpy as np
 from api.utils.file_utils import get_project_base_directory
 from rag.utils import num_tokens_from_string

+
 try:
     flag_model = FlagModel(os.path.join(
-
-
-
-
+        get_project_base_directory(),
+        "rag/res/bge-large-zh-v1.5"),
+        query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+        use_fp16=torch.cuda.is_available())
 except Exception as e:
     flag_model = FlagModel("BAAI/bge-large-zh-v1.5",
-
-
+                           query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
+                           use_fp16=torch.cuda.is_available())


 class Base(ABC):
@@ -82,8 +82,10 @@ class HuEmbedding(Base):


 class OpenAIEmbed(Base):
-    def __init__(self, key, model_name="text-embedding-ada-002",
-
+    def __init__(self, key, model_name="text-embedding-ada-002",
+                 base_url="https://api.openai.com/v1"):
+        if not base_url:
+            base_url = "https://api.openai.com/v1"
         self.client = OpenAI(api_key=key, base_url=base_url)
         self.model_name = model_name

@@ -142,7 +144,7 @@ class ZhipuEmbed(Base):
         tks_num = 0
         for txt in texts:
             res = self.client.embeddings.create(input=txt,
-
+                                                model=self.model_name)
             arr.append(res.data[0].embedding)
             tks_num += res.usage.total_tokens
         return np.array(arr), tks_num
@@ -163,14 +165,14 @@ class OllamaEmbed(Base):
         tks_num = 0
         for txt in texts:
             res = self.client.embeddings(prompt=txt,
-
+                                         model=self.model_name)
             arr.append(res["embedding"])
             tks_num += 128
         return np.array(arr), tks_num

     def encode_queries(self, text):
         res = self.client.embeddings(prompt=text,
-
+                                     model=self.model_name)
         return np.array(res["embedding"]), 128

@@ -183,10 +185,12 @@ class FastEmbed(Base):
         threads: Optional[int] = None,
         **kwargs,
     ):
+        from fastembed import TextEmbedding
         self._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)

     def encode(self, texts: list, batch_size=32):
-        # Using the internal tokenizer to encode the texts and get the total
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encodings = self._model.model.tokenizer.encode_batch(texts)
         total_tokens = sum(len(e) for e in encodings)
@@ -195,7 +199,8 @@
         return np.array(embeddings), total_tokens

     def encode_queries(self, text: str):
-        # Using the internal tokenizer to encode the texts and get the total
+        # Using the internal tokenizer to encode the texts and get the total
+        # number of tokens
         encoding = self._model.model.tokenizer.encode(text)
         embedding = next(self._model.query_embed(text)).tolist()
@@ -218,3 +223,33 @@ class XinferenceEmbed(Base):
             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens

+
+class QAnythingEmbed(Base):
+    _client = None
+
+    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
+        from BCEmbedding import EmbeddingModel as qanthing
+        if not QAnythingEmbed._client:
+            try:
+                print("LOADING BCE...")
+                QAnythingEmbed._client = qanthing(model_name_or_path=os.path.join(
+                    get_project_base_directory(),
+                    "rag/res/bce-embedding-base_v1"))
+            except Exception as e:
+                QAnythingEmbed._client = qanthing(
+                    model_name_or_path=model_name.replace(
+                        "maidalun1020", "InfiniFlow"))
+
+    def encode(self, texts: list, batch_size=10):
+        res = []
+        token_count = 0
+        for t in texts:
+            token_count += num_tokens_from_string(t)
+        for i in range(0, len(texts), batch_size):
+            embds = QAnythingEmbed._client.encode(texts[i:i + batch_size])
+            res.extend(embds)
+        return np.array(res), token_count
+
+    def encode_queries(self, text):
+        embds = QAnythingEmbed._client.encode([text])
+        return np.array(embds[0]), num_tokens_from_string(text)
```
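Both new backends keep the `Base` contract: `encode(texts)` for documents and `encode_queries(text)` for queries, each returning a numpy array plus a token count. A usage sketch for the FastEmbed backend (assumes `pip install fastembed`, network access for the first model download, and that the constructor accepts a `model_name` keyword naming a supported model such as BAAI/bge-small-en-v1.5; the exact signature is not shown in this hunk):

```python
from rag.llm.embedding_model import FastEmbed

mdl = FastEmbed(model_name="BAAI/bge-small-en-v1.5")
doc_vecs, doc_tokens = mdl.encode(["hello world", "light and fast embedding"])
q_vec, q_tokens = mdl.encode_queries("what is fastembed?")
print(doc_vecs.shape, doc_tokens, q_vec.shape, q_tokens)
```

QAnythingEmbed is used the same way; it lazily imports BCEmbedding, loads the weights once into the class-level `_client`, and shares them across instances.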
rag/nlp/search.py
CHANGED

```diff
@@ -46,7 +46,7 @@ class Dealer:
             "k": topk,
             "similarity": sim,
             "num_candidates": topk * 2,
-            "query_vector":
+            "query_vector": [float(v) for v in qv]
         }

     def search(self, req, idxnm, emb_mdl=None):
```
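The comprehension exists because the query vector arrives as numpy scalars, and `np.float32` is not JSON-serializable; the kNN request body must contain plain Python floats. A quick demonstration:

```python
import json
import numpy as np

qv = np.array([0.12, 0.34], dtype=np.float32)

try:
    json.dumps({"query_vector": list(qv)})            # still np.float32 scalars
except TypeError as e:
    print("raw numpy scalars fail:", e)

print(json.dumps({"query_vector": [float(v) for v in qv]}))  # serializes cleanly
```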
rag/svr/task_executor.py
CHANGED

```diff
@@ -244,8 +244,9 @@ def main(comm, mod):
     for _, r in rows.iterrows():
         callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
         try:
-            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
+            embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING, llm_name=r["embd_id"], lang=r["language"])
         except Exception as e:
+            traceback.print_stack(e)
             callback(prog=-1, msg=str(e))
             continue
```
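One wrinkle in the error path: `traceback.print_stack` expects a frame object, so passing the exception will itself raise; for a caught exception the conventional calls are `traceback.print_exc()` or `traceback.print_exception(e)`. A sketch of the intended behavior:

```python
import traceback

def build_embedding_model():
    raise LookupError("Model(bce-embedding-base_v1) not authorized")

try:
    build_embedding_model()
except Exception as e:
    traceback.print_exc()          # full traceback of the caught exception
    print("prog=-1, msg=%s" % e)   # mirrors callback(prog=-1, msg=str(e))
```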
requirements.txt
CHANGED

```diff
@@ -132,3 +132,5 @@ xpinyin==0.7.6
 xxhash==3.4.1
 yarl==1.9.4
 zhipuai==2.0.1
+BCEmbedding
+loguru==0.7.2
```
|