File size: 6,899 Bytes
3079197 484e5ab 3079197 3198faf e6acaf6 9bf75d4 e6acaf6 9bf75d4 3079197 51482f3 3079197 51482f3 3079197 004756c 3079197 6224edc 79ada0b 3198faf e6acaf6 3198faf 3079197 3198faf 6224edc 3198faf 79ada0b 3198faf 79ada0b 3198faf 6be3dd5 79ada0b 6be3dd5 79ada0b 8887e47 79ada0b b085dec 79ada0b b085dec 79ada0b b085dec 79ada0b b085dec 79ada0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from peewee import Expression
from api.db import TenantPermission, FileType, TaskStatus
from api.db.db_models import DB, Knowledgebase, Tenant
from api.db.db_models import Document
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db import StatusEnum
class DocumentService(CommonService):
    """Service-layer queries and updates for the ``Document`` model.

    Every method is a classmethod wrapped in ``DB.connection_context()``.
    Several methods also read or update the ``Knowledgebase`` a document
    belongs to.
    """

    model = Document

    @classmethod
    @DB.connection_context()
    def get_by_kb_id(cls, kb_id, page_number, items_per_page,
                     orderby, desc, keywords):
        """List one page of the documents of a knowledge base.

        Args:
            kb_id: Knowledge base id the documents must belong to.
            page_number: Page passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Field name resolved through ``cls.model.getter_by``.
            desc: Truthy for descending order, falsy for ascending.
            keywords: Optional substring matched against the document name
                with ``LIKE``; ignored when falsy.

        Returns:
            ``(rows, count)`` where ``rows`` is the requested page as a list
            of dicts and ``count`` is the number of matching documents
            counted before ordering and pagination are applied.
        """
        if keywords:
            docs = cls.model.select().where(
                cls.model.kb_id == kb_id,
                cls.model.name.like(f"%%{keywords}%%"))
        else:
            docs = cls.model.select().where(cls.model.kb_id == kb_id)

        # Count before paginating so the total covers all matches.
        count = docs.count()

        order_field = cls.model.getter_by(orderby)
        docs = docs.order_by(order_field.desc() if desc else order_field.asc())
        docs = docs.paginate(page_number, items_per_page)
        return list(docs.dicts()), count

    @classmethod
    @DB.connection_context()
    def insert(cls, doc):
        """Save a document and increment its knowledge base's ``doc_num``.

        Args:
            doc: Mapping of document fields; must contain ``"id"``.

        Returns:
            The document re-read through ``cls.get_by_id``.

        Raises:
            RuntimeError: If saving fails, the saved document or its
                knowledge base cannot be read back, or the knowledge base
                update fails.
        """
        if not cls.save(**doc):
            raise RuntimeError("Database error (Document)!")

        found, saved_doc = cls.get_by_id(doc["id"])
        if not found:
            raise RuntimeError("Database error (Document retrieval)!")

        # Check the lookup flag before touching ``kb``; otherwise a missing
        # knowledge base fails with an AttributeError on ``kb.id``.
        found, kb = KnowledgebaseService.get_by_id(saved_doc.kb_id)
        if not found:
            raise RuntimeError("Database error (Knowledgebase retrieval)!")

        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num + 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return saved_doc

    @classmethod
    @DB.connection_context()
    def delete(cls, doc):
        """Decrement the knowledge base's ``doc_num`` and delete the document.

        Args:
            doc: Document object exposing ``id`` and ``kb_id``.

        Returns:
            Whatever ``cls.delete_by_id`` returns for ``doc.id``.

        Raises:
            RuntimeError: If the knowledge base cannot be read or updated.
        """
        # Check the lookup flag before touching ``kb``; otherwise a missing
        # knowledge base fails with an AttributeError on ``kb.id``.
        found, kb = KnowledgebaseService.get_by_id(doc.kb_id)
        if not found:
            raise RuntimeError("Database error (Knowledgebase retrieval)!")

        if not KnowledgebaseService.update_by_id(
                kb.id, {"doc_num": kb.doc_num - 1}):
            raise RuntimeError("Database error (Knowledgebase)!")
        return cls.delete_by_id(doc.id)

    @classmethod
    @DB.connection_context()
    def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
        """Fetch documents that are marked as running but have no progress yet.

        Selects valid, non-virtual documents with ``progress == 0``,
        ``update_time >= tm`` and ``run == TaskStatus.RUNNING``, joined with
        their knowledge base and tenant, ordered by ``update_time`` ascending
        and limited to the first ``items_per_page`` rows.

        Args:
            tm: Lower bound (inclusive) for ``update_time``.
            mod: Value the ``create_time``/``comm`` expression must equal.
            comm: Right-hand operand of the ``create_time`` expression.
            items_per_page: Maximum number of rows returned.

        Returns:
            A list of dicts containing the selected document, knowledge base
            and tenant fields.
        """
        fields = [
            cls.model.id,
            cls.model.kb_id,
            cls.model.parser_id,
            cls.model.parser_config,
            cls.model.name,
            cls.model.type,
            cls.model.location,
            cls.model.size,
            Knowledgebase.tenant_id,
            Tenant.embd_id,
            Tenant.img2txt_id,
            Tenant.asr_id,
            cls.model.update_time,
        ]
        # NOTE(review): the "%%" operator string looks like an escaped SQL
        # modulo, i.e. ``create_time % comm == mod``, presumably used to
        # split documents between ``comm`` workers -- confirm against the
        # database driver in use.
        shard_filter = Expression(cls.model.create_time, "%%", comm) == mod
        docs = (
            cls.model.select(*fields)
            .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id))
            .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))
            .where(
                cls.model.status == StatusEnum.VALID.value,
                ~(cls.model.type == FileType.VIRTUAL.value),
                cls.model.progress == 0,
                cls.model.update_time >= tm,
                cls.model.run == TaskStatus.RUNNING.value,
                shard_filter)
            .order_by(cls.model.update_time.asc())
            .paginate(1, items_per_page)
        )
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def get_unfinished_docs(cls):
        """Return valid, non-virtual documents whose progress is in (0, 1).

        Returns:
            A list of dicts with the keys ``id`` and ``process_begin_at``.
        """
        fields = [cls.model.id, cls.model.process_begin_at]
        docs = cls.model.select(*fields).where(
            cls.model.status == StatusEnum.VALID.value,
            ~(cls.model.type == FileType.VIRTUAL.value),
            cls.model.progress < 1,
            cls.model.progress > 0)
        return list(docs.dicts())

    @classmethod
    @DB.connection_context()
    def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
        """Add token/chunk counts to a document and to its knowledge base.

        Args:
            doc_id: Id of the document to update.
            kb_id: Id of the knowledge base to update.
            token_num: Amount added to ``token_num`` on both records.
            chunk_num: Amount added to ``chunk_num`` on both records.
            duation: Amount added to the document's ``process_duation``.

        Returns:
            The result of executing the knowledge base UPDATE.

        Raises:
            LookupError: If the document UPDATE reports zero rows.
        """
        num = cls.model.update(
            token_num=cls.model.token_num + token_num,
            chunk_num=cls.model.chunk_num + chunk_num,
            process_duation=cls.model.process_duation + duation,
        ).where(cls.model.id == doc_id).execute()
        if num == 0:
            raise LookupError(
                "Document not found which is supposed to be there")

        num = Knowledgebase.update(
            token_num=Knowledgebase.token_num + token_num,
            chunk_num=Knowledgebase.chunk_num + chunk_num,
        ).where(Knowledgebase.id == kb_id).execute()
        return num

    @classmethod
    @DB.connection_context()
    def get_tenant_id(cls, doc_id):
        """Look up the tenant that owns a document via its knowledge base.

        Only knowledge bases whose status is ``StatusEnum.VALID`` are
        considered.

        Args:
            doc_id: Id of the document.

        Returns:
            The ``tenant_id`` of the first matching row, or ``None`` when
            there is no match.
        """
        docs = cls.model.select(Knowledgebase.tenant_id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(
            cls.model.id == doc_id,
            Knowledgebase.status == StatusEnum.VALID.value)
        docs = docs.dicts()
        if not docs:
            return
        return docs[0]["tenant_id"]

    @classmethod
    @DB.connection_context()
    def get_thumbnails(cls, docids):
        """Return ``id`` and ``thumbnail`` for the given document ids.

        Args:
            docids: Collection of document ids used in an ``IN`` filter.

        Returns:
            A list of dicts with the keys ``id`` and ``thumbnail``.
        """
        fields = [cls.model.id, cls.model.thumbnail]
        return list(
            cls.model.select(*fields).where(cls.model.id.in_(docids)).dicts())

    @classmethod
    @DB.connection_context()
    def update_parser_config(cls, id, config):
        """Recursively merge ``config`` into a document's ``parser_config``.

        Keys missing from the stored config are added. Keys holding dicts
        are merged recursively. Any other value overwrites the stored one.

        Args:
            id: Id of the document to update.
            config: Dict of parser settings to merge in.

        Raises:
            LookupError: If the document does not exist.
            AssertionError: If ``config`` holds a dict under a key whose
                stored value is not a dict (only while assertions are
                enabled).
        """
        e, d = cls.get_by_id(id)
        if not e:
            raise LookupError(f"Document({id}) not found.")

        def dfs_update(old, new):
            # Depth-first, in-place merge of ``new`` into ``old``.
            for k, v in new.items():
                if k not in old:
                    old[k] = v
                    continue
                if isinstance(v, dict):
                    assert isinstance(old[k], dict)
                    dfs_update(old[k], v)
                else:
                    old[k] = v

        dfs_update(d.parser_config, config)
        cls.update_by_id(id, {"parser_config": d.parser_config})

    @classmethod
    @DB.connection_context()
    def get_doc_count(cls, tenant_id):
        """Count the documents in all knowledge bases of a tenant.

        Args:
            tenant_id: Tenant whose documents are counted.

        Returns:
            The number of matching documents.
        """
        docs = cls.model.select(cls.model.id).join(
            Knowledgebase, on=(Knowledgebase.id == cls.model.kb_id)
        ).where(Knowledgebase.tenant_id == tenant_id)
        # Count in the database rather than loading every row for len().
        return docs.count()
|