KevinHuSh
commited on
Commit
·
3069c36
1
Parent(s):
1b1c88a
Refactor (#537)
Browse files### What problem does this PR solve?
### Type of change
- [x] Refactoring
- api/apps/document_app.py +2 -1
- api/apps/llm_app.py +2 -2
- api/db/db_models.py +1 -1
- api/db/init_data.py +5 -3
- api/db/services/llm_service.py +1 -1
- api/db/services/task_service.py +20 -0
- deepdoc/parser/pdf_parser.py +17 -17
- rag/llm/__init__.py +1 -1
- rag/llm/embedding_model.py +6 -6
- rag/svr/cache_file_svr.py +43 -0
- rag/svr/task_broker.py +1 -1
- rag/svr/task_executor.py +6 -0
- rag/utils/minio_conn.py +0 -1
- rag/utils/redis_conn.py +19 -0
api/apps/document_app.py
CHANGED
|
@@ -58,7 +58,8 @@ def upload():
|
|
| 58 |
if not e:
|
| 59 |
return get_data_error_result(
|
| 60 |
retmsg="Can't find this knowledgebase!")
|
| 61 |
-
|
|
|
|
| 62 |
return get_data_error_result(
|
| 63 |
retmsg="Exceed the maximum file number of a free user!")
|
| 64 |
|
|
|
|
| 58 |
if not e:
|
| 59 |
return get_data_error_result(
|
| 60 |
retmsg="Can't find this knowledgebase!")
|
| 61 |
+
MAX_FILE_NUM_PER_USER = int(os.environ.get('MAX_FILE_NUM_PER_USER', 0))
|
| 62 |
+
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
|
| 63 |
return get_data_error_result(
|
| 64 |
retmsg="Exceed the maximum file number of a free user!")
|
| 65 |
|
api/apps/llm_app.py
CHANGED
|
@@ -28,7 +28,7 @@ from rag.llm import EmbeddingModel, ChatModel
|
|
| 28 |
def factories():
|
| 29 |
try:
|
| 30 |
fac = LLMFactoriesService.get_all()
|
| 31 |
-
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["
|
| 32 |
except Exception as e:
|
| 33 |
return server_error_response(e)
|
| 34 |
|
|
@@ -174,7 +174,7 @@ def list():
|
|
| 174 |
llms = [m.to_dict()
|
| 175 |
for m in llms if m.status == StatusEnum.VALID.value]
|
| 176 |
for m in llms:
|
| 177 |
-
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["
|
| 178 |
|
| 179 |
llm_set = set([m["llm_name"] for m in llms])
|
| 180 |
for o in objs:
|
|
|
|
| 28 |
def factories():
|
| 29 |
try:
|
| 30 |
fac = LLMFactoriesService.get_all()
|
| 31 |
+
return get_json_result(data=[f.to_dict() for f in fac if f.name not in ["Youdao", "FastEmbed"]])
|
| 32 |
except Exception as e:
|
| 33 |
return server_error_response(e)
|
| 34 |
|
|
|
|
| 174 |
llms = [m.to_dict()
|
| 175 |
for m in llms if m.status == StatusEnum.VALID.value]
|
| 176 |
for m in llms:
|
| 177 |
+
m["available"] = m["fid"] in facts or m["llm_name"].lower() == "flag-embedding" or m["fid"] in ["Youdao","FastEmbed"]
|
| 178 |
|
| 179 |
llm_set = set([m["llm_name"] for m in llms])
|
| 180 |
for o in objs:
|
api/db/db_models.py
CHANGED
|
@@ -697,7 +697,7 @@ class Dialog(DataBaseModel):
|
|
| 697 |
null=True,
|
| 698 |
default="Chinese",
|
| 699 |
help_text="English|Chinese")
|
| 700 |
-
llm_id = CharField(max_length=
|
| 701 |
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
| 702 |
"presence_penalty": 0.4, "max_tokens": 215})
|
| 703 |
prompt_type = CharField(
|
|
|
|
| 697 |
null=True,
|
| 698 |
default="Chinese",
|
| 699 |
help_text="English|Chinese")
|
| 700 |
+
llm_id = CharField(max_length=128, null=False, help_text="default llm ID")
|
| 701 |
llm_setting = JSONField(null=False, default={"temperature": 0.1, "top_p": 0.3, "frequency_penalty": 0.7,
|
| 702 |
"presence_penalty": 0.4, "max_tokens": 215})
|
| 703 |
prompt_type = CharField(
|
api/db/init_data.py
CHANGED
|
@@ -120,7 +120,7 @@ factory_infos = [{
|
|
| 120 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 121 |
"status": "1",
|
| 122 |
},{
|
| 123 |
-
"name": "
|
| 124 |
"logo": "",
|
| 125 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 126 |
"status": "1",
|
|
@@ -323,7 +323,7 @@ def init_llm_factory():
|
|
| 323 |
"max_tokens": 2147483648,
|
| 324 |
"model_type": LLMType.EMBEDDING.value
|
| 325 |
},
|
| 326 |
-
# ------------------------
|
| 327 |
{
|
| 328 |
"fid": factory_infos[7]["name"],
|
| 329 |
"llm_name": "maidalun1020/bce-embedding-base_v1",
|
|
@@ -347,7 +347,9 @@ def init_llm_factory():
|
|
| 347 |
LLMService.filter_delete([LLM.fid == "Local"])
|
| 348 |
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
|
| 349 |
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
|
| 350 |
-
|
|
|
|
|
|
|
| 351 |
"""
|
| 352 |
drop table llm;
|
| 353 |
drop table llm_factories;
|
|
|
|
| 120 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 121 |
"status": "1",
|
| 122 |
},{
|
| 123 |
+
"name": "Youdao",
|
| 124 |
"logo": "",
|
| 125 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 126 |
"status": "1",
|
|
|
|
| 323 |
"max_tokens": 2147483648,
|
| 324 |
"model_type": LLMType.EMBEDDING.value
|
| 325 |
},
|
| 326 |
+
# ------------------------ Youdao -----------------------
|
| 327 |
{
|
| 328 |
"fid": factory_infos[7]["name"],
|
| 329 |
"llm_name": "maidalun1020/bce-embedding-base_v1",
|
|
|
|
| 347 |
LLMService.filter_delete([LLM.fid == "Local"])
|
| 348 |
LLMService.filter_delete([LLM.fid == "Moonshot", LLM.llm_name == "flag-embedding"])
|
| 349 |
TenantLLMService.filter_delete([TenantLLM.llm_factory == "Moonshot", TenantLLM.llm_name == "flag-embedding"])
|
| 350 |
+
LLMFactoriesService.filter_update([LLMFactoriesService.model.name == "QAnything"], {"name": "Youdao"})
|
| 351 |
+
LLMService.filter_update([LLMService.model.fid == "QAnything"], {"fid": "Youdao"})
|
| 352 |
+
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
| 353 |
"""
|
| 354 |
drop table llm;
|
| 355 |
drop table llm_factories;
|
api/db/services/llm_service.py
CHANGED
|
@@ -81,7 +81,7 @@ class TenantLLMService(CommonService):
|
|
| 81 |
if not model_config:
|
| 82 |
if llm_type == LLMType.EMBEDDING.value:
|
| 83 |
llm = LLMService.query(llm_name=llm_name)
|
| 84 |
-
if llm and llm[0].fid in ["
|
| 85 |
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
| 86 |
if not model_config:
|
| 87 |
if llm_name == "flag-embedding":
|
|
|
|
| 81 |
if not model_config:
|
| 82 |
if llm_type == LLMType.EMBEDDING.value:
|
| 83 |
llm = LLMService.query(llm_name=llm_name)
|
| 84 |
+
if llm and llm[0].fid in ["Youdao", "FastEmbed"]:
|
| 85 |
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
| 86 |
if not model_config:
|
| 87 |
if llm_name == "flag-embedding":
|
api/db/services/task_service.py
CHANGED
|
@@ -21,6 +21,7 @@ from api.db import StatusEnum, FileType, TaskStatus
|
|
| 21 |
from api.db.db_models import Task, Document, Knowledgebase, Tenant
|
| 22 |
from api.db.services.common_service import CommonService
|
| 23 |
from api.db.services.document_service import DocumentService
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class TaskService(CommonService):
|
|
@@ -70,6 +71,25 @@ class TaskService(CommonService):
|
|
| 70 |
cls.model.id == docs[0]["id"]).execute()
|
| 71 |
return docs
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
@classmethod
|
| 74 |
@DB.connection_context()
|
| 75 |
def do_cancel(cls, id):
|
|
|
|
| 21 |
from api.db.db_models import Task, Document, Knowledgebase, Tenant
|
| 22 |
from api.db.services.common_service import CommonService
|
| 23 |
from api.db.services.document_service import DocumentService
|
| 24 |
+
from api.utils import current_timestamp
|
| 25 |
|
| 26 |
|
| 27 |
class TaskService(CommonService):
|
|
|
|
| 71 |
cls.model.id == docs[0]["id"]).execute()
|
| 72 |
return docs
|
| 73 |
|
| 74 |
+
@classmethod
|
| 75 |
+
@DB.connection_context()
|
| 76 |
+
def get_ongoing_doc_name(cls):
|
| 77 |
+
with DB.lock("get_task", -1):
|
| 78 |
+
docs = cls.model.select(*[Document.kb_id, Document.location]) \
|
| 79 |
+
.join(Document, on=(cls.model.doc_id == Document.id)) \
|
| 80 |
+
.where(
|
| 81 |
+
Document.status == StatusEnum.VALID.value,
|
| 82 |
+
Document.run == TaskStatus.RUNNING.value,
|
| 83 |
+
~(Document.type == FileType.VIRTUAL.value),
|
| 84 |
+
cls.model.progress >= 0,
|
| 85 |
+
cls.model.progress < 1,
|
| 86 |
+
cls.model.create_time >= current_timestamp() - 180000
|
| 87 |
+
)
|
| 88 |
+
docs = list(docs.dicts())
|
| 89 |
+
if not docs: return []
|
| 90 |
+
|
| 91 |
+
return list(set([(d["kb_id"], d["location"]) for d in docs]))
|
| 92 |
+
|
| 93 |
@classmethod
|
| 94 |
@DB.connection_context()
|
| 95 |
def do_cancel(cls, id):
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -37,8 +37,8 @@ class HuParser:
|
|
| 37 |
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
| 38 |
try:
|
| 39 |
model_dir = os.path.join(
|
| 40 |
-
|
| 41 |
-
|
| 42 |
self.updown_cnt_mdl.load_model(os.path.join(
|
| 43 |
model_dir, "updown_concat_xgb.model"))
|
| 44 |
except Exception as e:
|
|
@@ -49,7 +49,6 @@ class HuParser:
|
|
| 49 |
self.updown_cnt_mdl.load_model(os.path.join(
|
| 50 |
model_dir, "updown_concat_xgb.model"))
|
| 51 |
|
| 52 |
-
|
| 53 |
self.page_from = 0
|
| 54 |
"""
|
| 55 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
|
@@ -76,7 +75,7 @@ class HuParser:
|
|
| 76 |
def _y_dis(
|
| 77 |
self, a, b):
|
| 78 |
return (
|
| 79 |
-
|
| 80 |
|
| 81 |
def _match_proj(self, b):
|
| 82 |
proj_patt = [
|
|
@@ -99,9 +98,9 @@ class HuParser:
|
|
| 99 |
tks_down = huqie.qie(down["text"][:LEN]).split(" ")
|
| 100 |
tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
|
| 101 |
tks_all = up["text"][-LEN:].strip() \
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
tks_all = huqie.qie(tks_all).split(" ")
|
| 106 |
fea = [
|
| 107 |
up.get("R", -1) == down.get("R", -1),
|
|
@@ -123,7 +122,7 @@ class HuParser:
|
|
| 123 |
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
| 124 |
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
| 125 |
True if re.search(r"[\((][^\))]+$", up["text"])
|
| 126 |
-
|
| 127 |
self._match_proj(down),
|
| 128 |
True if re.match(r"[A-Z]", down["text"]) else False,
|
| 129 |
True if re.match(r"[A-Z]", up["text"][-1]) else False,
|
|
@@ -185,7 +184,7 @@ class HuParser:
|
|
| 185 |
continue
|
| 186 |
for tb in tbls: # for table
|
| 187 |
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
|
| 188 |
-
|
| 189 |
left *= ZM
|
| 190 |
top *= ZM
|
| 191 |
right *= ZM
|
|
@@ -297,7 +296,7 @@ class HuParser:
|
|
| 297 |
for b in bxs:
|
| 298 |
if not b["text"]:
|
| 299 |
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
| 300 |
-
|
| 301 |
b["text"] = self.ocr.recognize(np.array(img),
|
| 302 |
np.array([[left, top], [right, top], [right, bott], [left, bott]],
|
| 303 |
dtype=np.float32))
|
|
@@ -622,7 +621,7 @@ class HuParser:
|
|
| 622 |
i += 1
|
| 623 |
continue
|
| 624 |
lout_no = str(self.boxes[i]["page_number"]) + \
|
| 625 |
-
|
| 626 |
if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
|
| 627 |
"title",
|
| 628 |
"figure caption",
|
|
@@ -975,6 +974,7 @@ class HuParser:
|
|
| 975 |
self.outlines.append((a["/Title"], depth))
|
| 976 |
continue
|
| 977 |
dfs(a, depth + 1)
|
|
|
|
| 978 |
dfs(outlines, 0)
|
| 979 |
except Exception as e:
|
| 980 |
logging.warning(f"Outlines exception: {e}")
|
|
@@ -984,7 +984,7 @@ class HuParser:
|
|
| 984 |
logging.info("Images converted.")
|
| 985 |
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
|
| 986 |
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
|
| 987 |
-
|
| 988 |
if sum([1 if e else 0 for e in self.is_english]) > len(
|
| 989 |
self.page_images) / 2:
|
| 990 |
self.is_english = True
|
|
@@ -1012,9 +1012,9 @@ class HuParser:
|
|
| 1012 |
j += 1
|
| 1013 |
|
| 1014 |
self.__ocr(i + 1, img, chars, zoomin)
|
| 1015 |
-
|
| 1016 |
-
|
| 1017 |
-
#print("OCR:", timer()-st)
|
| 1018 |
|
| 1019 |
if not self.is_english and not any(
|
| 1020 |
[c for c in self.page_chars]) and self.boxes:
|
|
@@ -1050,7 +1050,7 @@ class HuParser:
|
|
| 1050 |
left, right, top, bottom = float(left), float(
|
| 1051 |
right), float(top), float(bottom)
|
| 1052 |
poss.append(([int(p) - 1 for p in pn.split("-")],
|
| 1053 |
-
|
| 1054 |
if not poss:
|
| 1055 |
if need_position:
|
| 1056 |
return None, None
|
|
@@ -1076,7 +1076,7 @@ class HuParser:
|
|
| 1076 |
self.page_images[pns[0]].crop((left * ZM, top * ZM,
|
| 1077 |
right *
|
| 1078 |
ZM, min(
|
| 1079 |
-
|
| 1080 |
))
|
| 1081 |
)
|
| 1082 |
if 0 < ii < len(poss) - 1:
|
|
|
|
| 37 |
self.updown_cnt_mdl.set_param({"device": "cuda"})
|
| 38 |
try:
|
| 39 |
model_dir = os.path.join(
|
| 40 |
+
get_project_base_directory(),
|
| 41 |
+
"rag/res/deepdoc")
|
| 42 |
self.updown_cnt_mdl.load_model(os.path.join(
|
| 43 |
model_dir, "updown_concat_xgb.model"))
|
| 44 |
except Exception as e:
|
|
|
|
| 49 |
self.updown_cnt_mdl.load_model(os.path.join(
|
| 50 |
model_dir, "updown_concat_xgb.model"))
|
| 51 |
|
|
|
|
| 52 |
self.page_from = 0
|
| 53 |
"""
|
| 54 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
|
|
|
| 75 |
def _y_dis(
|
| 76 |
self, a, b):
|
| 77 |
return (
|
| 78 |
+
b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2
|
| 79 |
|
| 80 |
def _match_proj(self, b):
|
| 81 |
proj_patt = [
|
|
|
|
| 98 |
tks_down = huqie.qie(down["text"][:LEN]).split(" ")
|
| 99 |
tks_up = huqie.qie(up["text"][-LEN:]).split(" ")
|
| 100 |
tks_all = up["text"][-LEN:].strip() \
|
| 101 |
+
+ (" " if re.match(r"[a-zA-Z0-9]+",
|
| 102 |
+
up["text"][-1] + down["text"][0]) else "") \
|
| 103 |
+
+ down["text"][:LEN].strip()
|
| 104 |
tks_all = huqie.qie(tks_all).split(" ")
|
| 105 |
fea = [
|
| 106 |
up.get("R", -1) == down.get("R", -1),
|
|
|
|
| 122 |
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
| 123 |
True if re.search(r"[,,][^。.]+$", up["text"]) else False,
|
| 124 |
True if re.search(r"[\((][^\))]+$", up["text"])
|
| 125 |
+
and re.search(r"[\))]", down["text"]) else False,
|
| 126 |
self._match_proj(down),
|
| 127 |
True if re.match(r"[A-Z]", down["text"]) else False,
|
| 128 |
True if re.match(r"[A-Z]", up["text"][-1]) else False,
|
|
|
|
| 184 |
continue
|
| 185 |
for tb in tbls: # for table
|
| 186 |
left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
|
| 187 |
+
tb["x1"] + MARGIN, tb["bottom"] + MARGIN
|
| 188 |
left *= ZM
|
| 189 |
top *= ZM
|
| 190 |
right *= ZM
|
|
|
|
| 296 |
for b in bxs:
|
| 297 |
if not b["text"]:
|
| 298 |
left, right, top, bott = b["x0"] * ZM, b["x1"] * \
|
| 299 |
+
ZM, b["top"] * ZM, b["bottom"] * ZM
|
| 300 |
b["text"] = self.ocr.recognize(np.array(img),
|
| 301 |
np.array([[left, top], [right, top], [right, bott], [left, bott]],
|
| 302 |
dtype=np.float32))
|
|
|
|
| 621 |
i += 1
|
| 622 |
continue
|
| 623 |
lout_no = str(self.boxes[i]["page_number"]) + \
|
| 624 |
+
"-" + str(self.boxes[i]["layoutno"])
|
| 625 |
if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
|
| 626 |
"title",
|
| 627 |
"figure caption",
|
|
|
|
| 974 |
self.outlines.append((a["/Title"], depth))
|
| 975 |
continue
|
| 976 |
dfs(a, depth + 1)
|
| 977 |
+
|
| 978 |
dfs(outlines, 0)
|
| 979 |
except Exception as e:
|
| 980 |
logging.warning(f"Outlines exception: {e}")
|
|
|
|
| 984 |
logging.info("Images converted.")
|
| 985 |
self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
|
| 986 |
random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
|
| 987 |
+
range(len(self.page_chars))]
|
| 988 |
if sum([1 if e else 0 for e in self.is_english]) > len(
|
| 989 |
self.page_images) / 2:
|
| 990 |
self.is_english = True
|
|
|
|
| 1012 |
j += 1
|
| 1013 |
|
| 1014 |
self.__ocr(i + 1, img, chars, zoomin)
|
| 1015 |
+
if callback and i % 6 == 5:
|
| 1016 |
+
callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
|
| 1017 |
+
# print("OCR:", timer()-st)
|
| 1018 |
|
| 1019 |
if not self.is_english and not any(
|
| 1020 |
[c for c in self.page_chars]) and self.boxes:
|
|
|
|
| 1050 |
left, right, top, bottom = float(left), float(
|
| 1051 |
right), float(top), float(bottom)
|
| 1052 |
poss.append(([int(p) - 1 for p in pn.split("-")],
|
| 1053 |
+
left, right, top, bottom))
|
| 1054 |
if not poss:
|
| 1055 |
if need_position:
|
| 1056 |
return None, None
|
|
|
|
| 1076 |
self.page_images[pns[0]].crop((left * ZM, top * ZM,
|
| 1077 |
right *
|
| 1078 |
ZM, min(
|
| 1079 |
+
bottom, self.page_images[pns[0]].size[1])
|
| 1080 |
))
|
| 1081 |
)
|
| 1082 |
if 0 < ii < len(poss) - 1:
|
rag/llm/__init__.py
CHANGED
|
@@ -25,7 +25,7 @@ EmbeddingModel = {
|
|
| 25 |
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
| 26 |
"ZHIPU-AI": ZhipuEmbed,
|
| 27 |
"FastEmbed": FastEmbed,
|
| 28 |
-
"
|
| 29 |
}
|
| 30 |
|
| 31 |
|
|
|
|
| 25 |
"Tongyi-Qianwen": HuEmbedding, #QWenEmbed,
|
| 26 |
"ZHIPU-AI": ZhipuEmbed,
|
| 27 |
"FastEmbed": FastEmbed,
|
| 28 |
+
"Youdao": YoudaoEmbed
|
| 29 |
}
|
| 30 |
|
| 31 |
|
rag/llm/embedding_model.py
CHANGED
|
@@ -229,19 +229,19 @@ class XinferenceEmbed(Base):
|
|
| 229 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
| 230 |
|
| 231 |
|
| 232 |
-
class
|
| 233 |
_client = None
|
| 234 |
|
| 235 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
| 236 |
from BCEmbedding import EmbeddingModel as qanthing
|
| 237 |
-
if not
|
| 238 |
try:
|
| 239 |
print("LOADING BCE...")
|
| 240 |
-
|
| 241 |
get_project_base_directory(),
|
| 242 |
"rag/res/bce-embedding-base_v1"))
|
| 243 |
except Exception as e:
|
| 244 |
-
|
| 245 |
model_name_or_path=model_name.replace(
|
| 246 |
"maidalun1020", "InfiniFlow"))
|
| 247 |
|
|
@@ -251,10 +251,10 @@ class QAnythingEmbed(Base):
|
|
| 251 |
for t in texts:
|
| 252 |
token_count += num_tokens_from_string(t)
|
| 253 |
for i in range(0, len(texts), batch_size):
|
| 254 |
-
embds =
|
| 255 |
res.extend(embds)
|
| 256 |
return np.array(res), token_count
|
| 257 |
|
| 258 |
def encode_queries(self, text):
|
| 259 |
-
embds =
|
| 260 |
return np.array(embds[0]), num_tokens_from_string(text)
|
|
|
|
| 229 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
| 230 |
|
| 231 |
|
| 232 |
+
class YoudaoEmbed(Base):
|
| 233 |
_client = None
|
| 234 |
|
| 235 |
def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
|
| 236 |
from BCEmbedding import EmbeddingModel as qanthing
|
| 237 |
+
if not YoudaoEmbed._client:
|
| 238 |
try:
|
| 239 |
print("LOADING BCE...")
|
| 240 |
+
YoudaoEmbed._client = qanthing(model_name_or_path=os.path.join(
|
| 241 |
get_project_base_directory(),
|
| 242 |
"rag/res/bce-embedding-base_v1"))
|
| 243 |
except Exception as e:
|
| 244 |
+
YoudaoEmbed._client = qanthing(
|
| 245 |
model_name_or_path=model_name.replace(
|
| 246 |
"maidalun1020", "InfiniFlow"))
|
| 247 |
|
|
|
|
| 251 |
for t in texts:
|
| 252 |
token_count += num_tokens_from_string(t)
|
| 253 |
for i in range(0, len(texts), batch_size):
|
| 254 |
+
embds = YoudaoEmbed._client.encode(texts[i:i + batch_size])
|
| 255 |
res.extend(embds)
|
| 256 |
return np.array(res), token_count
|
| 257 |
|
| 258 |
def encode_queries(self, text):
|
| 259 |
+
embds = YoudaoEmbed._client.encode([text])
|
| 260 |
return np.array(embds[0]), num_tokens_from_string(text)
|
rag/svr/cache_file_svr.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import time
|
| 3 |
+
import traceback
|
| 4 |
+
|
| 5 |
+
from api.db.db_models import close_connection
|
| 6 |
+
from api.db.services.task_service import TaskService
|
| 7 |
+
from rag.utils import MINIO
|
| 8 |
+
from rag.utils.redis_conn import REDIS_CONN
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def collect():
|
| 12 |
+
doc_locations = TaskService.get_ongoing_doc_name()
|
| 13 |
+
#print(tasks)
|
| 14 |
+
if len(doc_locations) == 0:
|
| 15 |
+
time.sleep(1)
|
| 16 |
+
return
|
| 17 |
+
return doc_locations
|
| 18 |
+
|
| 19 |
+
def main():
|
| 20 |
+
locations = collect()
|
| 21 |
+
if not locations:return
|
| 22 |
+
print("TASKS:", len(locations))
|
| 23 |
+
for kb_id, loc in locations:
|
| 24 |
+
try:
|
| 25 |
+
if REDIS_CONN.is_alive():
|
| 26 |
+
try:
|
| 27 |
+
key = "{}/{}".format(kb_id, loc)
|
| 28 |
+
if REDIS_CONN.exist(key):continue
|
| 29 |
+
file_bin = MINIO.get(kb_id, loc)
|
| 30 |
+
REDIS_CONN.transaction(key, file_bin, 12 * 60)
|
| 31 |
+
print("CACHE:", loc)
|
| 32 |
+
except Exception as e:
|
| 33 |
+
traceback.print_stack(e)
|
| 34 |
+
except Exception as e:
|
| 35 |
+
traceback.print_stack(e)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
while True:
|
| 41 |
+
main()
|
| 42 |
+
close_connection()
|
| 43 |
+
time.sleep(1)
|
rag/svr/task_broker.py
CHANGED
|
@@ -167,7 +167,7 @@ def update_progress():
|
|
| 167 |
info = {
|
| 168 |
"process_duation": datetime.timestamp(
|
| 169 |
datetime.now()) -
|
| 170 |
-
|
| 171 |
"run": status}
|
| 172 |
if prg != 0:
|
| 173 |
info["progress"] = prg
|
|
|
|
| 167 |
info = {
|
| 168 |
"process_duation": datetime.timestamp(
|
| 169 |
datetime.now()) -
|
| 170 |
+
d["process_begin_at"].timestamp(),
|
| 171 |
"run": status}
|
| 172 |
if prg != 0:
|
| 173 |
info["progress"] = prg
|
rag/svr/task_executor.py
CHANGED
|
@@ -107,8 +107,14 @@ def get_minio_binary(bucket, name):
|
|
| 107 |
global MINIO
|
| 108 |
if REDIS_CONN.is_alive():
|
| 109 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
r = REDIS_CONN.get("{}/{}".format(bucket, name))
|
| 111 |
if r: return r
|
|
|
|
| 112 |
except Exception as e:
|
| 113 |
cron_logger.warning("Get redis[EXCEPTION]:" + str(e))
|
| 114 |
return MINIO.get(bucket, name)
|
|
|
|
| 107 |
global MINIO
|
| 108 |
if REDIS_CONN.is_alive():
|
| 109 |
try:
|
| 110 |
+
for _ in range(30):
|
| 111 |
+
if REDIS_CONN.exist("{}/{}".format(bucket, name)):
|
| 112 |
+
time.sleep(1)
|
| 113 |
+
break
|
| 114 |
+
time.sleep(1)
|
| 115 |
r = REDIS_CONN.get("{}/{}".format(bucket, name))
|
| 116 |
if r: return r
|
| 117 |
+
cron_logger.warning("Cache missing: {}".format(name))
|
| 118 |
except Exception as e:
|
| 119 |
cron_logger.warning("Get redis[EXCEPTION]:" + str(e))
|
| 120 |
return MINIO.get(bucket, name)
|
rag/utils/minio_conn.py
CHANGED
|
@@ -56,7 +56,6 @@ class HuMinio(object):
|
|
| 56 |
except Exception as e:
|
| 57 |
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
|
| 58 |
|
| 59 |
-
|
| 60 |
def get(self, bucket, fnm):
|
| 61 |
for _ in range(1):
|
| 62 |
try:
|
|
|
|
| 56 |
except Exception as e:
|
| 57 |
minio_logger.error(f"Fail rm {bucket}/{fnm}: " + str(e))
|
| 58 |
|
|
|
|
| 59 |
def get(self, bucket, fnm):
|
| 60 |
for _ in range(1):
|
| 61 |
try:
|
rag/utils/redis_conn.py
CHANGED
|
@@ -25,6 +25,14 @@ class RedisDB:
|
|
| 25 |
def is_alive(self):
|
| 26 |
return self.REDIS is not None
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def get(self, k):
|
| 29 |
if not self.REDIS: return
|
| 30 |
try:
|
|
@@ -51,5 +59,16 @@ class RedisDB:
|
|
| 51 |
self.__open__()
|
| 52 |
return False
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
REDIS_CONN = RedisDB()
|
|
|
|
| 25 |
def is_alive(self):
|
| 26 |
return self.REDIS is not None
|
| 27 |
|
| 28 |
+
def exist(self, k):
|
| 29 |
+
if not self.REDIS: return
|
| 30 |
+
try:
|
| 31 |
+
return self.REDIS.exists(k)
|
| 32 |
+
except Exception as e:
|
| 33 |
+
logging.warning("[EXCEPTION]exist" + str(k) + "||" + str(e))
|
| 34 |
+
self.__open__()
|
| 35 |
+
|
| 36 |
def get(self, k):
|
| 37 |
if not self.REDIS: return
|
| 38 |
try:
|
|
|
|
| 59 |
self.__open__()
|
| 60 |
return False
|
| 61 |
|
| 62 |
+
def transaction(self, key, value, exp=3600):
|
| 63 |
+
try:
|
| 64 |
+
pipeline = self.REDIS.pipeline(transaction=True)
|
| 65 |
+
pipeline.set(key, value, exp, nx=True)
|
| 66 |
+
pipeline.execute()
|
| 67 |
+
return True
|
| 68 |
+
except Exception as e:
|
| 69 |
+
logging.warning("[EXCEPTION]set" + str(key) + "||" + str(e))
|
| 70 |
+
self.__open__()
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
|
| 74 |
REDIS_CONN = RedisDB()
|