KevinHuSh
commited on
Commit
·
3198faf
1
Parent(s):
3079197
add alot of api (#23)
Browse files* clean rust version project
* clean rust version project
* build python version rag-flow
* add alot of api
- rag/llm/embedding_model.py +1 -1
- rag/nlp/huchunk.py +6 -3
- rag/nlp/search.py +1 -1
- rag/svr/parse_user_docs.py +35 -19
- rag/utils/__init__.py +19 -0
- rag/utils/es_conn.py +1 -0
- web_server/apps/document_app.py +47 -2
- web_server/apps/kb_app.py +14 -2
- web_server/apps/llm_app.py +95 -0
- web_server/apps/user_app.py +33 -5
- web_server/db/db_models.py +7 -4
- web_server/db/services/document_service.py +27 -13
- web_server/db/services/kb_service.py +33 -6
- web_server/db/services/llm_service.py +18 -0
- web_server/db/services/user_service.py +1 -1
- web_server/utils/file_utils.py +1 -1
rag/llm/embedding_model.py
CHANGED
|
@@ -35,7 +35,7 @@ class Base(ABC):
|
|
| 35 |
|
| 36 |
|
| 37 |
class HuEmbedding(Base):
|
| 38 |
-
def __init__(self):
|
| 39 |
"""
|
| 40 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
| 41 |
|
|
|
|
| 35 |
|
| 36 |
|
| 37 |
class HuEmbedding(Base):
|
| 38 |
+
def __init__(self, key="", model_name=""):
|
| 39 |
"""
|
| 40 |
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
| 41 |
|
rag/nlp/huchunk.py
CHANGED
|
@@ -411,9 +411,12 @@ class TextChunker(HuChunker):
|
|
| 411 |
flds = self.Fields()
|
| 412 |
if self.is_binary_file(fnm):
|
| 413 |
return flds
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
| 417 |
flds.table_chunks = []
|
| 418 |
return flds
|
| 419 |
|
|
|
|
| 411 |
flds = self.Fields()
|
| 412 |
if self.is_binary_file(fnm):
|
| 413 |
return flds
|
| 414 |
+
txt = ""
|
| 415 |
+
if isinstance(fnm, str):
|
| 416 |
+
with open(fnm, "r") as f:
|
| 417 |
+
txt = f.read()
|
| 418 |
+
else: txt = fnm.decode("utf-8")
|
| 419 |
+
flds.text_chunks = [(c, None) for c in self.naive_text_chunk(txt)]
|
| 420 |
flds.table_chunks = []
|
| 421 |
return flds
|
| 422 |
|
rag/nlp/search.py
CHANGED
|
@@ -8,7 +8,7 @@ from rag.nlp import huqie, query
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
|
| 11 |
-
def index_name(uid): return f"
|
| 12 |
|
| 13 |
|
| 14 |
class Dealer:
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
|
| 10 |
|
| 11 |
+
def index_name(uid): return f"ragflow_{uid}"
|
| 12 |
|
| 13 |
|
| 14 |
class Dealer:
|
rag/svr/parse_user_docs.py
CHANGED
|
@@ -14,6 +14,7 @@
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
|
|
|
| 17 |
import os
|
| 18 |
import hashlib
|
| 19 |
import copy
|
|
@@ -24,9 +25,10 @@ from timeit import default_timer as timer
|
|
| 24 |
|
| 25 |
from rag.llm import EmbeddingModel, CvModel
|
| 26 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
| 27 |
-
from rag.utils import ELASTICSEARCH
|
| 28 |
from rag.utils import MINIO
|
| 29 |
-
from rag.utils import rmSpace,
|
|
|
|
| 30 |
from rag.nlp import huchunk, huqie, search
|
| 31 |
from io import BytesIO
|
| 32 |
import pandas as pd
|
|
@@ -47,6 +49,7 @@ from rag.nlp.huchunk import (
|
|
| 47 |
from web_server.db import LLMType
|
| 48 |
from web_server.db.services.document_service import DocumentService
|
| 49 |
from web_server.db.services.llm_service import TenantLLMService
|
|
|
|
| 50 |
from web_server.utils import get_format_time
|
| 51 |
from web_server.utils.file_utils import get_project_base_directory
|
| 52 |
|
|
@@ -83,7 +86,7 @@ def collect(comm, mod, tm):
|
|
| 83 |
if len(docs) == 0:
|
| 84 |
return pd.DataFrame()
|
| 85 |
docs = pd.DataFrame(docs)
|
| 86 |
-
mtm =
|
| 87 |
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
| 88 |
return docs
|
| 89 |
|
|
@@ -99,11 +102,12 @@ def set_progress(docid, prog, msg="Processing...", begin=False):
|
|
| 99 |
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
| 100 |
|
| 101 |
|
| 102 |
-
def build(row):
|
| 103 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
| 104 |
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
| 105 |
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
| 106 |
return []
|
|
|
|
| 107 |
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
| 108 |
if ELASTICSEARCH.getTotal(res) > 0:
|
| 109 |
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
|
@@ -120,7 +124,8 @@ def build(row):
|
|
| 120 |
set_progress(row["id"], random.randint(0, 20) /
|
| 121 |
100., "Finished preparing! Start to slice file!", True)
|
| 122 |
try:
|
| 123 |
-
|
|
|
|
| 124 |
except Exception as e:
|
| 125 |
if re.search("(No such file|not found)", str(e)):
|
| 126 |
set_progress(
|
|
@@ -131,6 +136,9 @@ def build(row):
|
|
| 131 |
row["id"], -1, f"Internal server error: %s" %
|
| 132 |
str(e).replace(
|
| 133 |
"'", ""))
|
|
|
|
|
|
|
|
|
|
| 134 |
return []
|
| 135 |
|
| 136 |
if not obj.text_chunks and not obj.table_chunks:
|
|
@@ -144,7 +152,7 @@ def build(row):
|
|
| 144 |
"Finished slicing files. Start to embedding the content.")
|
| 145 |
|
| 146 |
doc = {
|
| 147 |
-
"doc_id": row["
|
| 148 |
"kb_id": [str(row["kb_id"])],
|
| 149 |
"docnm_kwd": os.path.split(row["location"])[-1],
|
| 150 |
"title_tks": huqie.qie(row["name"]),
|
|
@@ -164,10 +172,10 @@ def build(row):
|
|
| 164 |
docs.append(d)
|
| 165 |
continue
|
| 166 |
|
| 167 |
-
if isinstance(img,
|
| 168 |
-
img.save(output_buffer, format='JPEG')
|
| 169 |
-
else:
|
| 170 |
output_buffer = BytesIO(img)
|
|
|
|
|
|
|
| 171 |
|
| 172 |
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
| 173 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
|
@@ -215,15 +223,16 @@ def embedding(docs, mdl):
|
|
| 215 |
|
| 216 |
|
| 217 |
def model_instance(tenant_id, llm_type):
|
| 218 |
-
model_config = TenantLLMService.
|
| 219 |
-
if not model_config:
|
| 220 |
-
|
|
|
|
| 221 |
if llm_type == LLMType.EMBEDDING:
|
| 222 |
-
if model_config
|
| 223 |
-
return EmbeddingModel[model_config
|
| 224 |
if llm_type == LLMType.IMAGE2TEXT:
|
| 225 |
-
if model_config
|
| 226 |
-
return CvModel[model_config.llm_factory](model_config
|
| 227 |
|
| 228 |
|
| 229 |
def main(comm, mod):
|
|
@@ -231,7 +240,7 @@ def main(comm, mod):
|
|
| 231 |
from rag.llm import HuEmbedding
|
| 232 |
model = HuEmbedding()
|
| 233 |
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
| 234 |
-
tm =
|
| 235 |
rows = collect(comm, mod, tm)
|
| 236 |
if len(rows) == 0:
|
| 237 |
return
|
|
@@ -247,7 +256,7 @@ def main(comm, mod):
|
|
| 247 |
st_tm = timer()
|
| 248 |
cks = build(r, cv_mdl)
|
| 249 |
if not cks:
|
| 250 |
-
tmf.write(str(r["
|
| 251 |
continue
|
| 252 |
# TODO: exception handler
|
| 253 |
## set_progress(r["did"], -1, "ERROR: ")
|
|
@@ -268,12 +277,19 @@ def main(comm, mod):
|
|
| 268 |
cron_logger.error(str(es_r))
|
| 269 |
else:
|
| 270 |
set_progress(r["id"], 1., "Done!")
|
| 271 |
-
DocumentService.
|
|
|
|
|
|
|
| 272 |
tmf.write(str(r["update_time"]) + "\n")
|
| 273 |
tmf.close()
|
| 274 |
|
| 275 |
|
| 276 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
from mpi4py import MPI
|
| 278 |
comm = MPI.COMM_WORLD
|
| 279 |
main(comm.Get_size(), comm.Get_rank())
|
|
|
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
import json
|
| 17 |
+
import logging
|
| 18 |
import os
|
| 19 |
import hashlib
|
| 20 |
import copy
|
|
|
|
| 25 |
|
| 26 |
from rag.llm import EmbeddingModel, CvModel
|
| 27 |
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
|
| 28 |
+
from rag.utils import ELASTICSEARCH
|
| 29 |
from rag.utils import MINIO
|
| 30 |
+
from rag.utils import rmSpace, findMaxTm
|
| 31 |
+
|
| 32 |
from rag.nlp import huchunk, huqie, search
|
| 33 |
from io import BytesIO
|
| 34 |
import pandas as pd
|
|
|
|
| 49 |
from web_server.db import LLMType
|
| 50 |
from web_server.db.services.document_service import DocumentService
|
| 51 |
from web_server.db.services.llm_service import TenantLLMService
|
| 52 |
+
from web_server.settings import database_logger
|
| 53 |
from web_server.utils import get_format_time
|
| 54 |
from web_server.utils.file_utils import get_project_base_directory
|
| 55 |
|
|
|
|
| 86 |
if len(docs) == 0:
|
| 87 |
return pd.DataFrame()
|
| 88 |
docs = pd.DataFrame(docs)
|
| 89 |
+
mtm = docs["update_time"].max()
|
| 90 |
cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
|
| 91 |
return docs
|
| 92 |
|
|
|
|
| 102 |
cron_logger.error("set_progress:({}), {}".format(docid, str(e)))
|
| 103 |
|
| 104 |
|
| 105 |
+
def build(row, cvmdl):
|
| 106 |
if row["size"] > DOC_MAXIMUM_SIZE:
|
| 107 |
set_progress(row["id"], -1, "File size exceeds( <= %dMb )" %
|
| 108 |
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
|
| 109 |
return []
|
| 110 |
+
|
| 111 |
res = ELASTICSEARCH.search(Q("term", doc_id=row["id"]))
|
| 112 |
if ELASTICSEARCH.getTotal(res) > 0:
|
| 113 |
ELASTICSEARCH.updateScriptByQuery(Q("term", doc_id=row["id"]),
|
|
|
|
| 124 |
set_progress(row["id"], random.randint(0, 20) /
|
| 125 |
100., "Finished preparing! Start to slice file!", True)
|
| 126 |
try:
|
| 127 |
+
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
|
| 128 |
+
obj = chuck_doc(row["name"], MINIO.get(row["kb_id"], row["location"]), cvmdl)
|
| 129 |
except Exception as e:
|
| 130 |
if re.search("(No such file|not found)", str(e)):
|
| 131 |
set_progress(
|
|
|
|
| 136 |
row["id"], -1, f"Internal server error: %s" %
|
| 137 |
str(e).replace(
|
| 138 |
"'", ""))
|
| 139 |
+
|
| 140 |
+
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
|
| 141 |
+
|
| 142 |
return []
|
| 143 |
|
| 144 |
if not obj.text_chunks and not obj.table_chunks:
|
|
|
|
| 152 |
"Finished slicing files. Start to embedding the content.")
|
| 153 |
|
| 154 |
doc = {
|
| 155 |
+
"doc_id": row["id"],
|
| 156 |
"kb_id": [str(row["kb_id"])],
|
| 157 |
"docnm_kwd": os.path.split(row["location"])[-1],
|
| 158 |
"title_tks": huqie.qie(row["name"]),
|
|
|
|
| 172 |
docs.append(d)
|
| 173 |
continue
|
| 174 |
|
| 175 |
+
if isinstance(img, bytes):
|
|
|
|
|
|
|
| 176 |
output_buffer = BytesIO(img)
|
| 177 |
+
else:
|
| 178 |
+
img.save(output_buffer, format='JPEG')
|
| 179 |
|
| 180 |
MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
|
| 181 |
d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
|
|
|
|
| 223 |
|
| 224 |
|
| 225 |
def model_instance(tenant_id, llm_type):
|
| 226 |
+
model_config = TenantLLMService.get_api_key(tenant_id, model_type=LLMType.EMBEDDING)
|
| 227 |
+
if not model_config:
|
| 228 |
+
model_config = {"llm_factory": "local", "api_key": "", "llm_name": ""}
|
| 229 |
+
else: model_config = model_config[0].to_dict()
|
| 230 |
if llm_type == LLMType.EMBEDDING:
|
| 231 |
+
if model_config["llm_factory"] not in EmbeddingModel: return
|
| 232 |
+
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"])
|
| 233 |
if llm_type == LLMType.IMAGE2TEXT:
|
| 234 |
+
if model_config["llm_factory"] not in CvModel: return
|
| 235 |
+
return CvModel[model_config.llm_factory](model_config["api_key"], model_config["llm_name"])
|
| 236 |
|
| 237 |
|
| 238 |
def main(comm, mod):
|
|
|
|
| 240 |
from rag.llm import HuEmbedding
|
| 241 |
model = HuEmbedding()
|
| 242 |
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
|
| 243 |
+
tm = findMaxTm(tm_fnm)
|
| 244 |
rows = collect(comm, mod, tm)
|
| 245 |
if len(rows) == 0:
|
| 246 |
return
|
|
|
|
| 256 |
st_tm = timer()
|
| 257 |
cks = build(r, cv_mdl)
|
| 258 |
if not cks:
|
| 259 |
+
tmf.write(str(r["update_time"]) + "\n")
|
| 260 |
continue
|
| 261 |
# TODO: exception handler
|
| 262 |
## set_progress(r["did"], -1, "ERROR: ")
|
|
|
|
| 277 |
cron_logger.error(str(es_r))
|
| 278 |
else:
|
| 279 |
set_progress(r["id"], 1., "Done!")
|
| 280 |
+
DocumentService.increment_chunk_num(r["id"], r["kb_id"], tk_count, len(cks), timer()-st_tm)
|
| 281 |
+
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
|
| 282 |
+
|
| 283 |
tmf.write(str(r["update_time"]) + "\n")
|
| 284 |
tmf.close()
|
| 285 |
|
| 286 |
|
| 287 |
if __name__ == "__main__":
|
| 288 |
+
peewee_logger = logging.getLogger('peewee')
|
| 289 |
+
peewee_logger.propagate = False
|
| 290 |
+
peewee_logger.addHandler(database_logger.handlers[0])
|
| 291 |
+
peewee_logger.setLevel(database_logger.level)
|
| 292 |
+
|
| 293 |
from mpi4py import MPI
|
| 294 |
comm = MPI.COMM_WORLD
|
| 295 |
main(comm.Get_size(), comm.Get_rank())
|
rag/utils/__init__.py
CHANGED
|
@@ -40,6 +40,25 @@ def findMaxDt(fnm):
|
|
| 40 |
print("WARNING: can't find " + fnm)
|
| 41 |
return m
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def num_tokens_from_string(string: str) -> int:
|
| 44 |
"""Returns the number of tokens in a text string."""
|
| 45 |
encoding = tiktoken.get_encoding('cl100k_base')
|
|
|
|
| 40 |
print("WARNING: can't find " + fnm)
|
| 41 |
return m
|
| 42 |
|
| 43 |
+
|
| 44 |
+
def findMaxTm(fnm):
|
| 45 |
+
m = 0
|
| 46 |
+
try:
|
| 47 |
+
with open(fnm, "r") as f:
|
| 48 |
+
while True:
|
| 49 |
+
l = f.readline()
|
| 50 |
+
if not l:
|
| 51 |
+
break
|
| 52 |
+
l = l.strip("\n")
|
| 53 |
+
if l == 'nan':
|
| 54 |
+
continue
|
| 55 |
+
if int(l) > m:
|
| 56 |
+
m = int(l)
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print("WARNING: can't find " + fnm)
|
| 59 |
+
return m
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def num_tokens_from_string(string: str) -> int:
|
| 63 |
"""Returns the number of tokens in a text string."""
|
| 64 |
encoding = tiktoken.get_encoding('cl100k_base')
|
rag/utils/es_conn.py
CHANGED
|
@@ -294,6 +294,7 @@ class HuEs:
|
|
| 294 |
except Exception as e:
|
| 295 |
es_logger.error("ES updateByQuery deleteByQuery: " +
|
| 296 |
str(e) + "【Q】:" + str(query.to_dict()))
|
|
|
|
| 297 |
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
| 298 |
continue
|
| 299 |
|
|
|
|
| 294 |
except Exception as e:
|
| 295 |
es_logger.error("ES updateByQuery deleteByQuery: " +
|
| 296 |
str(e) + "【Q】:" + str(query.to_dict()))
|
| 297 |
+
if str(e).find("NotFoundError") > 0: return True
|
| 298 |
if str(e).find("Timeout") > 0 or str(e).find("Conflict") > 0:
|
| 299 |
continue
|
| 300 |
|
web_server/apps/document_app.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
|
|
|
| 16 |
import pathlib
|
| 17 |
|
| 18 |
from elasticsearch_dsl import Q
|
|
@@ -195,11 +196,15 @@ def rm():
|
|
| 195 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 196 |
if not e:
|
| 197 |
return get_data_error_result(retmsg="Document not found!")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
if not DocumentService.delete_by_id(req["doc_id"]):
|
| 199 |
return get_data_error_result(
|
| 200 |
retmsg="Database error (Document removal)!")
|
| 201 |
-
|
| 202 |
-
MINIO.rm(
|
| 203 |
return get_json_result(data=True)
|
| 204 |
except Exception as e:
|
| 205 |
return server_error_response(e)
|
|
@@ -233,3 +238,43 @@ def rename():
|
|
| 233 |
return get_json_result(data=True)
|
| 234 |
except Exception as e:
|
| 235 |
return server_error_response(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
import base64
|
| 17 |
import pathlib
|
| 18 |
|
| 19 |
from elasticsearch_dsl import Q
|
|
|
|
| 196 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 197 |
if not e:
|
| 198 |
return get_data_error_result(retmsg="Document not found!")
|
| 199 |
+
if not ELASTICSEARCH.deleteByQuery(Q("match", doc_id=doc.id), idxnm=search.index_name(doc.kb_id)):
|
| 200 |
+
return get_json_result(data=False, retmsg='Remove from ES failure"', retcode=RetCode.SERVER_ERROR)
|
| 201 |
+
|
| 202 |
+
DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, 0)
|
| 203 |
if not DocumentService.delete_by_id(req["doc_id"]):
|
| 204 |
return get_data_error_result(
|
| 205 |
retmsg="Database error (Document removal)!")
|
| 206 |
+
|
| 207 |
+
MINIO.rm(doc.kb_id, doc.location)
|
| 208 |
return get_json_result(data=True)
|
| 209 |
except Exception as e:
|
| 210 |
return server_error_response(e)
|
|
|
|
| 238 |
return get_json_result(data=True)
|
| 239 |
except Exception as e:
|
| 240 |
return server_error_response(e)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@manager.route('/get', methods=['GET'])
|
| 244 |
+
@login_required
|
| 245 |
+
def get():
|
| 246 |
+
doc_id = request.args["doc_id"]
|
| 247 |
+
try:
|
| 248 |
+
e, doc = DocumentService.get_by_id(doc_id)
|
| 249 |
+
if not e:
|
| 250 |
+
return get_data_error_result(retmsg="Document not found!")
|
| 251 |
+
|
| 252 |
+
blob = MINIO.get(doc.kb_id, doc.location)
|
| 253 |
+
return get_json_result(data={"base64": base64.b64decode(blob)})
|
| 254 |
+
except Exception as e:
|
| 255 |
+
return server_error_response(e)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
@manager.route('/change_parser', methods=['POST'])
|
| 259 |
+
@login_required
|
| 260 |
+
@validate_request("doc_id", "parser_id")
|
| 261 |
+
def change_parser():
|
| 262 |
+
req = request.json
|
| 263 |
+
try:
|
| 264 |
+
e, doc = DocumentService.get_by_id(req["doc_id"])
|
| 265 |
+
if not e:
|
| 266 |
+
return get_data_error_result(retmsg="Document not found!")
|
| 267 |
+
if doc.parser_id.lower() == req["parser_id"].lower():
|
| 268 |
+
return get_json_result(data=True)
|
| 269 |
+
|
| 270 |
+
e = DocumentService.update_by_id(doc.id, {"parser_id": req["parser_id"], "progress":0, "progress_msg": ""})
|
| 271 |
+
if not e:
|
| 272 |
+
return get_data_error_result(retmsg="Document not found!")
|
| 273 |
+
e = DocumentService.increment_chunk_num(doc.id, doc.kb_id, doc.token_num*-1, doc.chunk_num*-1, doc.process_duation*-1)
|
| 274 |
+
if not e:
|
| 275 |
+
return get_data_error_result(retmsg="Document not found!")
|
| 276 |
+
|
| 277 |
+
return get_json_result(data=True)
|
| 278 |
+
except Exception as e:
|
| 279 |
+
return server_error_response(e)
|
| 280 |
+
|
web_server/apps/kb_app.py
CHANGED
|
@@ -29,7 +29,7 @@ from web_server.utils.api_utils import get_json_result
|
|
| 29 |
|
| 30 |
@manager.route('/create', methods=['post'])
|
| 31 |
@login_required
|
| 32 |
-
@validate_request("name", "description", "permission", "
|
| 33 |
def create():
|
| 34 |
req = request.json
|
| 35 |
req["name"] = req["name"].strip()
|
|
@@ -46,7 +46,7 @@ def create():
|
|
| 46 |
|
| 47 |
@manager.route('/update', methods=['post'])
|
| 48 |
@login_required
|
| 49 |
-
@validate_request("kb_id", "name", "description", "permission", "
|
| 50 |
def update():
|
| 51 |
req = request.json
|
| 52 |
req["name"] = req["name"].strip()
|
|
@@ -72,6 +72,18 @@ def update():
|
|
| 72 |
return server_error_response(e)
|
| 73 |
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
@manager.route('/list', methods=['GET'])
|
| 76 |
@login_required
|
| 77 |
def list():
|
|
|
|
| 29 |
|
| 30 |
@manager.route('/create', methods=['post'])
|
| 31 |
@login_required
|
| 32 |
+
@validate_request("name", "description", "permission", "parser_id")
|
| 33 |
def create():
|
| 34 |
req = request.json
|
| 35 |
req["name"] = req["name"].strip()
|
|
|
|
| 46 |
|
| 47 |
@manager.route('/update', methods=['post'])
|
| 48 |
@login_required
|
| 49 |
+
@validate_request("kb_id", "name", "description", "permission", "parser_id")
|
| 50 |
def update():
|
| 51 |
req = request.json
|
| 52 |
req["name"] = req["name"].strip()
|
|
|
|
| 72 |
return server_error_response(e)
|
| 73 |
|
| 74 |
|
| 75 |
+
@manager.route('/detail', methods=['GET'])
|
| 76 |
+
@login_required
|
| 77 |
+
def detail():
|
| 78 |
+
kb_id = request.args["kb_id"]
|
| 79 |
+
try:
|
| 80 |
+
kb = KnowledgebaseService.get_detail(kb_id)
|
| 81 |
+
if not kb: return get_data_error_result(retmsg="Can't find this knowledgebase!")
|
| 82 |
+
return get_json_result(data=kb)
|
| 83 |
+
except Exception as e:
|
| 84 |
+
return server_error_response(e)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
@manager.route('/list', methods=['GET'])
|
| 88 |
@login_required
|
| 89 |
def list():
|
web_server/apps/llm_app.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
|
| 2 |
+
# Copyright 2019 The FATE Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
#
|
| 16 |
+
from flask import request
|
| 17 |
+
from flask_login import login_required, current_user
|
| 18 |
+
|
| 19 |
+
from web_server.db.services import duplicate_name
|
| 20 |
+
from web_server.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
| 21 |
+
from web_server.db.services.user_service import TenantService, UserTenantService
|
| 22 |
+
from web_server.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
| 23 |
+
from web_server.utils import get_uuid, get_format_time
|
| 24 |
+
from web_server.db import StatusEnum, UserTenantRole
|
| 25 |
+
from web_server.db.services.kb_service import KnowledgebaseService
|
| 26 |
+
from web_server.db.db_models import Knowledgebase, TenantLLM
|
| 27 |
+
from web_server.settings import stat_logger, RetCode
|
| 28 |
+
from web_server.utils.api_utils import get_json_result
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@manager.route('/factories', methods=['GET'])
|
| 32 |
+
@login_required
|
| 33 |
+
def factories():
|
| 34 |
+
try:
|
| 35 |
+
fac = LLMFactoriesService.get_all()
|
| 36 |
+
return get_json_result(data=fac.to_json())
|
| 37 |
+
except Exception as e:
|
| 38 |
+
return server_error_response(e)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@manager.route('/set_api_key', methods=['POST'])
|
| 42 |
+
@login_required
|
| 43 |
+
@validate_request("llm_factory", "api_key")
|
| 44 |
+
def set_api_key():
|
| 45 |
+
req = request.json
|
| 46 |
+
llm = {
|
| 47 |
+
"tenant_id": current_user.id,
|
| 48 |
+
"llm_factory": req["llm_factory"],
|
| 49 |
+
"api_key": req["api_key"]
|
| 50 |
+
}
|
| 51 |
+
# TODO: Test api_key
|
| 52 |
+
for n in ["model_type", "llm_name"]:
|
| 53 |
+
if n in req: llm[n] = req[n]
|
| 54 |
+
|
| 55 |
+
TenantLLM.insert(**llm).on_conflict("replace").execute()
|
| 56 |
+
return get_json_result(data=True)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@manager.route('/my_llms', methods=['GET'])
|
| 60 |
+
@login_required
|
| 61 |
+
def my_llms():
|
| 62 |
+
try:
|
| 63 |
+
objs = TenantLLMService.query(tenant_id=current_user.id)
|
| 64 |
+
objs = [o.to_dict() for o in objs]
|
| 65 |
+
for o in objs: del o["api_key"]
|
| 66 |
+
return get_json_result(data=objs)
|
| 67 |
+
except Exception as e:
|
| 68 |
+
return server_error_response(e)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@manager.route('/list', methods=['GET'])
|
| 72 |
+
@login_required
|
| 73 |
+
def list():
|
| 74 |
+
try:
|
| 75 |
+
objs = TenantLLMService.query(tenant_id=current_user.id)
|
| 76 |
+
objs = [o.to_dict() for o in objs if o.api_key]
|
| 77 |
+
fct = {}
|
| 78 |
+
for o in objs:
|
| 79 |
+
if o["llm_factory"] not in fct: fct[o["llm_factory"]] = []
|
| 80 |
+
if o["llm_name"]: fct[o["llm_factory"]].append(o["llm_name"])
|
| 81 |
+
|
| 82 |
+
llms = LLMService.get_all()
|
| 83 |
+
llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
|
| 84 |
+
for m in llms:
|
| 85 |
+
m["available"] = False
|
| 86 |
+
if m["fid"] in fct and (not fct[m["fid"]] or m["llm_name"] in fct[m["fid"]]):
|
| 87 |
+
m["available"] = True
|
| 88 |
+
res = {}
|
| 89 |
+
for m in llms:
|
| 90 |
+
if m["fid"] not in res: res[m["fid"]] = []
|
| 91 |
+
res[m["fid"]].append(m)
|
| 92 |
+
|
| 93 |
+
return get_json_result(data=res)
|
| 94 |
+
except Exception as e:
|
| 95 |
+
return server_error_response(e)
|
web_server/apps/user_app.py
CHANGED
|
@@ -16,9 +16,12 @@
|
|
| 16 |
from flask import request, session, redirect, url_for
|
| 17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 18 |
from flask_login import login_required, current_user, login_user, logout_user
|
|
|
|
|
|
|
|
|
|
| 19 |
from web_server.utils.api_utils import server_error_response, validate_request
|
| 20 |
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
| 21 |
-
from web_server.db import UserTenantRole
|
| 22 |
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
| 23 |
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
| 24 |
from web_server.settings import stat_logger
|
|
@@ -47,8 +50,9 @@ def login():
|
|
| 47 |
avatar = download_img(userinfo["avatar_url"])
|
| 48 |
except Exception as e:
|
| 49 |
stat_logger.exception(e)
|
|
|
|
| 50 |
try:
|
| 51 |
-
users = user_register({
|
| 52 |
"access_token": session["access_token"],
|
| 53 |
"email": userinfo["email"],
|
| 54 |
"avatar": avatar,
|
|
@@ -63,6 +67,7 @@ def login():
|
|
| 63 |
login_user(user)
|
| 64 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
| 65 |
except Exception as e:
|
|
|
|
| 66 |
stat_logger.exception(e)
|
| 67 |
return server_error_response(e)
|
| 68 |
elif not request.json:
|
|
@@ -162,7 +167,25 @@ def user_info():
|
|
| 162 |
return get_json_result(data=current_user.to_dict())
|
| 163 |
|
| 164 |
|
| 165 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
user_id = get_uuid()
|
| 167 |
user["id"] = user_id
|
| 168 |
tenant = {
|
|
@@ -180,10 +203,12 @@ def user_register(user):
|
|
| 180 |
"invited_by": user_id,
|
| 181 |
"role": UserTenantRole.OWNER
|
| 182 |
}
|
|
|
|
| 183 |
|
| 184 |
if not UserService.save(**user):return
|
| 185 |
TenantService.save(**tenant)
|
| 186 |
UserTenantService.save(**usr_tenant)
|
|
|
|
| 187 |
return UserService.query(email=user["email"])
|
| 188 |
|
| 189 |
|
|
@@ -203,14 +228,17 @@ def user_add():
|
|
| 203 |
"last_login_time": get_format_time(),
|
| 204 |
"is_superuser": False,
|
| 205 |
}
|
|
|
|
|
|
|
| 206 |
try:
|
| 207 |
-
users = user_register(user_dict)
|
| 208 |
if not users: raise Exception('Register user failure.')
|
| 209 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 210 |
user = users[0]
|
| 211 |
login_user(user)
|
| 212 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
| 213 |
except Exception as e:
|
|
|
|
| 214 |
stat_logger.exception(e)
|
| 215 |
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
| 216 |
|
|
@@ -220,7 +248,7 @@ def user_add():
|
|
| 220 |
@login_required
|
| 221 |
def tenant_info():
|
| 222 |
try:
|
| 223 |
-
tenants = TenantService.get_by_user_id(current_user.id)
|
| 224 |
return get_json_result(data=tenants)
|
| 225 |
except Exception as e:
|
| 226 |
return server_error_response(e)
|
|
|
|
| 16 |
from flask import request, session, redirect, url_for
|
| 17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 18 |
from flask_login import login_required, current_user, login_user, logout_user
|
| 19 |
+
|
| 20 |
+
from web_server.db.db_models import TenantLLM
|
| 21 |
+
from web_server.db.services.llm_service import TenantLLMService
|
| 22 |
from web_server.utils.api_utils import server_error_response, validate_request
|
| 23 |
from web_server.utils import get_uuid, get_format_time, decrypt, download_img
|
| 24 |
+
from web_server.db import UserTenantRole, LLMType
|
| 25 |
from web_server.settings import RetCode, GITHUB_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS
|
| 26 |
from web_server.db.services.user_service import UserService, TenantService, UserTenantService
|
| 27 |
from web_server.settings import stat_logger
|
|
|
|
| 50 |
avatar = download_img(userinfo["avatar_url"])
|
| 51 |
except Exception as e:
|
| 52 |
stat_logger.exception(e)
|
| 53 |
+
user_id = get_uuid()
|
| 54 |
try:
|
| 55 |
+
users = user_register(user_id, {
|
| 56 |
"access_token": session["access_token"],
|
| 57 |
"email": userinfo["email"],
|
| 58 |
"avatar": avatar,
|
|
|
|
| 67 |
login_user(user)
|
| 68 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome back!")
|
| 69 |
except Exception as e:
|
| 70 |
+
rollback_user_registration(user_id)
|
| 71 |
stat_logger.exception(e)
|
| 72 |
return server_error_response(e)
|
| 73 |
elif not request.json:
|
|
|
|
| 167 |
return get_json_result(data=current_user.to_dict())
|
| 168 |
|
| 169 |
|
| 170 |
+
def rollback_user_registration(user_id):
|
| 171 |
+
try:
|
| 172 |
+
TenantService.delete_by_id(user_id)
|
| 173 |
+
except Exception as e:
|
| 174 |
+
pass
|
| 175 |
+
try:
|
| 176 |
+
u = UserTenantService.query(tenant_id=user_id)
|
| 177 |
+
if u:
|
| 178 |
+
UserTenantService.delete_by_id(u[0].id)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
pass
|
| 181 |
+
try:
|
| 182 |
+
TenantLLM.delete().where(TenantLLM.tenant_id==user_id).excute()
|
| 183 |
+
except Exception as e:
|
| 184 |
+
pass
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def user_register(user_id, user):
|
| 188 |
+
|
| 189 |
user_id = get_uuid()
|
| 190 |
user["id"] = user_id
|
| 191 |
tenant = {
|
|
|
|
| 203 |
"invited_by": user_id,
|
| 204 |
"role": UserTenantRole.OWNER
|
| 205 |
}
|
| 206 |
+
tenant_llm = {"tenant_id": user_id, "llm_factory": "OpenAI", "api_key": "infiniflow API Key"}
|
| 207 |
|
| 208 |
if not UserService.save(**user):return
|
| 209 |
TenantService.save(**tenant)
|
| 210 |
UserTenantService.save(**usr_tenant)
|
| 211 |
+
TenantLLMService.save(**tenant_llm)
|
| 212 |
return UserService.query(email=user["email"])
|
| 213 |
|
| 214 |
|
|
|
|
| 228 |
"last_login_time": get_format_time(),
|
| 229 |
"is_superuser": False,
|
| 230 |
}
|
| 231 |
+
|
| 232 |
+
user_id = get_uuid()
|
| 233 |
try:
|
| 234 |
+
users = user_register(user_id, user_dict)
|
| 235 |
if not users: raise Exception('Register user failure.')
|
| 236 |
if len(users) > 1: raise Exception('Same E-mail exist!')
|
| 237 |
user = users[0]
|
| 238 |
login_user(user)
|
| 239 |
return cors_reponse(data=user.to_json(), auth=user.get_id(), retmsg="Welcome aboard!")
|
| 240 |
except Exception as e:
|
| 241 |
+
rollback_user_registration(user_id)
|
| 242 |
stat_logger.exception(e)
|
| 243 |
return get_json_result(data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
|
| 244 |
|
|
|
|
| 248 |
@login_required
|
| 249 |
def tenant_info():
|
| 250 |
try:
|
| 251 |
+
tenants = TenantService.get_by_user_id(current_user.id)[0]
|
| 252 |
return get_json_result(data=tenants)
|
| 253 |
except Exception as e:
|
| 254 |
return server_error_response(e)
|
web_server/db/db_models.py
CHANGED
|
@@ -428,6 +428,7 @@ class LLMFactories(DataBaseModel):
|
|
| 428 |
class LLM(DataBaseModel):
|
| 429 |
# defautlt LLMs for every users
|
| 430 |
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
|
|
|
| 431 |
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
| 432 |
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
| 433 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
|
@@ -442,8 +443,8 @@ class LLM(DataBaseModel):
|
|
| 442 |
class TenantLLM(DataBaseModel):
|
| 443 |
tenant_id = CharField(max_length=32, null=False)
|
| 444 |
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
| 445 |
-
model_type = CharField(max_length=128, null=
|
| 446 |
-
llm_name = CharField(max_length=128, null=
|
| 447 |
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
| 448 |
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
| 449 |
|
|
@@ -452,7 +453,7 @@ class TenantLLM(DataBaseModel):
|
|
| 452 |
|
| 453 |
class Meta:
|
| 454 |
db_table = "tenant_llm"
|
| 455 |
-
primary_key = CompositeKey('tenant_id', 'llm_factory')
|
| 456 |
|
| 457 |
|
| 458 |
class Knowledgebase(DataBaseModel):
|
|
@@ -464,7 +465,9 @@ class Knowledgebase(DataBaseModel):
|
|
| 464 |
permission = CharField(max_length=16, null=False, help_text="me|team")
|
| 465 |
created_by = CharField(max_length=32, null=False)
|
| 466 |
doc_num = IntegerField(default=0)
|
| 467 |
-
|
|
|
|
|
|
|
| 468 |
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
| 469 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
| 470 |
|
|
|
|
| 428 |
class LLM(DataBaseModel):
|
| 429 |
# defautlt LLMs for every users
|
| 430 |
llm_name = CharField(max_length=128, null=False, help_text="LLM name", primary_key=True)
|
| 431 |
+
model_type = CharField(max_length=128, null=False, help_text="LLM, Text Embedding, Image2Text, ASR")
|
| 432 |
fid = CharField(max_length=128, null=False, help_text="LLM factory id")
|
| 433 |
tags = CharField(max_length=255, null=False, help_text="LLM, Text Embedding, Image2Text, Chat, 32k...")
|
| 434 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
|
|
|
| 443 |
class TenantLLM(DataBaseModel):
|
| 444 |
tenant_id = CharField(max_length=32, null=False)
|
| 445 |
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name")
|
| 446 |
+
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR")
|
| 447 |
+
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="")
|
| 448 |
api_key = CharField(max_length=255, null=True, help_text="API KEY")
|
| 449 |
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
| 450 |
|
|
|
|
| 453 |
|
| 454 |
class Meta:
|
| 455 |
db_table = "tenant_llm"
|
| 456 |
+
primary_key = CompositeKey('tenant_id', 'llm_factory', 'llm_name')
|
| 457 |
|
| 458 |
|
| 459 |
class Knowledgebase(DataBaseModel):
|
|
|
|
| 465 |
permission = CharField(max_length=16, null=False, help_text="me|team")
|
| 466 |
created_by = CharField(max_length=32, null=False)
|
| 467 |
doc_num = IntegerField(default=0)
|
| 468 |
+
token_num = IntegerField(default=0)
|
| 469 |
+
chunk_num = IntegerField(default=0)
|
| 470 |
+
|
| 471 |
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
| 472 |
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
| 473 |
|
web_server/db/services/document_service.py
CHANGED
|
@@ -13,12 +13,13 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
|
|
|
|
|
|
| 16 |
from web_server.db import TenantPermission, FileType
|
| 17 |
-
from web_server.db.db_models import DB, Knowledgebase
|
| 18 |
from web_server.db.db_models import Document
|
| 19 |
from web_server.db.services.common_service import CommonService
|
| 20 |
from web_server.db.services.kb_service import KnowledgebaseService
|
| 21 |
-
from web_server.utils import get_uuid, get_format_time
|
| 22 |
from web_server.db.db_utils import StatusEnum
|
| 23 |
|
| 24 |
|
|
@@ -61,15 +62,28 @@ class DocumentService(CommonService):
|
|
| 61 |
@classmethod
|
| 62 |
@DB.connection_context()
|
| 63 |
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
| 64 |
-
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, Knowledgebase.tenant_id]
|
| 65 |
-
docs = cls.model.select(fields)
|
| 66 |
-
cls.model.
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
| 75 |
return list(docs.dicts())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
from peewee import Expression
|
| 17 |
+
|
| 18 |
from web_server.db import TenantPermission, FileType
|
| 19 |
+
from web_server.db.db_models import DB, Knowledgebase, Tenant
|
| 20 |
from web_server.db.db_models import Document
|
| 21 |
from web_server.db.services.common_service import CommonService
|
| 22 |
from web_server.db.services.kb_service import KnowledgebaseService
|
|
|
|
| 23 |
from web_server.db.db_utils import StatusEnum
|
| 24 |
|
| 25 |
|
|
|
|
| 62 |
@classmethod
|
| 63 |
@DB.connection_context()
|
| 64 |
def get_newly_uploaded(cls, tm, mod, comm, items_per_page=64):
|
| 65 |
+
fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, cls.model.update_time]
|
| 66 |
+
docs = cls.model.select(*fields) \
|
| 67 |
+
.join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
|
| 68 |
+
.join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
|
| 69 |
+
.where(
|
| 70 |
+
cls.model.status == StatusEnum.VALID.value,
|
| 71 |
+
~(cls.model.type == FileType.VIRTUAL.value),
|
| 72 |
+
cls.model.progress == 0,
|
| 73 |
+
cls.model.update_time >= tm,
|
| 74 |
+
(Expression(cls.model.create_time, "%%", comm) == mod))\
|
| 75 |
+
.order_by(cls.model.update_time.asc())\
|
| 76 |
+
.paginate(1, items_per_page)
|
| 77 |
return list(docs.dicts())
|
| 78 |
+
|
| 79 |
+
@classmethod
|
| 80 |
+
@DB.connection_context()
|
| 81 |
+
def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duation):
|
| 82 |
+
num = cls.model.update(token_num=cls.model.token_num + token_num,
|
| 83 |
+
chunk_num=cls.model.chunk_num + chunk_num,
|
| 84 |
+
process_duation=cls.model.process_duation+duation).where(
|
| 85 |
+
cls.model.id == doc_id).execute()
|
| 86 |
+
if num == 0:raise LookupError("Document not found which is supposed to be there")
|
| 87 |
+
num = Knowledgebase.update(token_num=Knowledgebase.token_num+token_num, chunk_num=Knowledgebase.chunk_num+chunk_num).where(Knowledgebase.id==kb_id).execute()
|
| 88 |
+
return num
|
| 89 |
+
|
web_server/db/services/kb_service.py
CHANGED
|
@@ -17,7 +17,7 @@ import peewee
|
|
| 17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 18 |
|
| 19 |
from web_server.db import TenantPermission
|
| 20 |
-
from web_server.db.db_models import DB, UserTenant
|
| 21 |
from web_server.db.db_models import Knowledgebase
|
| 22 |
from web_server.db.services.common_service import CommonService
|
| 23 |
from web_server.utils import get_uuid, get_format_time
|
|
@@ -29,15 +29,42 @@ class KnowledgebaseService(CommonService):
|
|
| 29 |
|
| 30 |
@classmethod
|
| 31 |
@DB.connection_context()
|
| 32 |
-
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
|
|
|
| 33 |
kbs = cls.model.select().where(
|
| 34 |
-
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
| 35 |
-
|
|
|
|
| 36 |
)
|
| 37 |
-
if desc:
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
|
| 40 |
kbs = kbs.paginate(page_number, items_per_page)
|
| 41 |
|
| 42 |
return list(kbs.dicts())
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from werkzeug.security import generate_password_hash, check_password_hash
|
| 18 |
|
| 19 |
from web_server.db import TenantPermission
|
| 20 |
+
from web_server.db.db_models import DB, UserTenant, Tenant
|
| 21 |
from web_server.db.db_models import Knowledgebase
|
| 22 |
from web_server.db.services.common_service import CommonService
|
| 23 |
from web_server.utils import get_uuid, get_format_time
|
|
|
|
| 29 |
|
| 30 |
@classmethod
|
| 31 |
@DB.connection_context()
|
| 32 |
+
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
| 33 |
+
page_number, items_per_page, orderby, desc):
|
| 34 |
kbs = cls.model.select().where(
|
| 35 |
+
((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission ==
|
| 36 |
+
TenantPermission.TEAM.value)) | (cls.model.tenant_id == user_id))
|
| 37 |
+
& (cls.model.status == StatusEnum.VALID.value)
|
| 38 |
)
|
| 39 |
+
if desc:
|
| 40 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
| 41 |
+
else:
|
| 42 |
+
kbs = kbs.order_by(cls.model.getter_by(orderby).asc())
|
| 43 |
|
| 44 |
kbs = kbs.paginate(page_number, items_per_page)
|
| 45 |
|
| 46 |
return list(kbs.dicts())
|
| 47 |
|
| 48 |
+
@classmethod
|
| 49 |
+
@DB.connection_context()
|
| 50 |
+
def get_detail(cls, kb_id):
|
| 51 |
+
fields = [
|
| 52 |
+
cls.model.id,
|
| 53 |
+
Tenant.embd_id,
|
| 54 |
+
cls.model.avatar,
|
| 55 |
+
cls.model.name,
|
| 56 |
+
cls.model.description,
|
| 57 |
+
cls.model.permission,
|
| 58 |
+
cls.model.doc_num,
|
| 59 |
+
cls.model.token_num,
|
| 60 |
+
cls.model.chunk_num,
|
| 61 |
+
cls.model.parser_id]
|
| 62 |
+
kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
|
| 63 |
+
(cls.model.id == kb_id),
|
| 64 |
+
(cls.model.status == StatusEnum.VALID.value)
|
| 65 |
+
)
|
| 66 |
+
if not kbs:
|
| 67 |
+
return
|
| 68 |
+
d = kbs[0].to_dict()
|
| 69 |
+
d["embd_id"] = kbs[0].tenant.embd_id
|
| 70 |
+
return d
|
web_server/db/services/llm_service.py
CHANGED
|
@@ -33,3 +33,21 @@ class LLMService(CommonService):
|
|
| 33 |
|
| 34 |
class TenantLLMService(CommonService):
|
| 35 |
model = TenantLLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
class TenantLLMService(CommonService):
|
| 35 |
model = TenantLLM
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
@DB.connection_context()
|
| 39 |
+
def get_api_key(cls, tenant_id, model_type):
|
| 40 |
+
objs = cls.query(tenant_id=tenant_id, model_type=model_type)
|
| 41 |
+
if objs and len(objs)>0 and objs[0].llm_name:
|
| 42 |
+
return objs[0]
|
| 43 |
+
|
| 44 |
+
fields = [LLM.llm_name, cls.model.llm_factory, cls.model.api_key]
|
| 45 |
+
objs = cls.model.select(*fields).join(LLM, on=(LLM.fid == cls.model.llm_factory)).where(
|
| 46 |
+
(cls.model.tenant_id == tenant_id),
|
| 47 |
+
(cls.model.model_type == model_type),
|
| 48 |
+
(LLM.status == StatusEnum.VALID)
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
if not objs:return
|
| 52 |
+
return objs[0]
|
| 53 |
+
|
web_server/db/services/user_service.py
CHANGED
|
@@ -79,7 +79,7 @@ class TenantService(CommonService):
|
|
| 79 |
@classmethod
|
| 80 |
@DB.connection_context()
|
| 81 |
def get_by_user_id(cls, user_id):
|
| 82 |
-
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, UserTenant.role]
|
| 83 |
return list(cls.model.select(*fields)\
|
| 84 |
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
| 85 |
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
|
|
|
| 79 |
@classmethod
|
| 80 |
@DB.connection_context()
|
| 81 |
def get_by_user_id(cls, user_id):
|
| 82 |
+
fields = [cls.model.id.alias("tenant_id"), cls.model.name, cls.model.llm_id, cls.model.embd_id, cls.model.asr_id, cls.model.img2txt_id, cls.model.parser_ids, UserTenant.role]
|
| 83 |
return list(cls.model.select(*fields)\
|
| 84 |
.join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value)))\
|
| 85 |
.where(cls.model.status == StatusEnum.VALID.value).dicts())
|
web_server/utils/file_utils.py
CHANGED
|
@@ -143,7 +143,7 @@ def filename_type(filename):
|
|
| 143 |
if re.match(r".*\.pdf$", filename):
|
| 144 |
return FileType.PDF.value
|
| 145 |
|
| 146 |
-
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key)$", filename):
|
| 147 |
return FileType.DOC.value
|
| 148 |
|
| 149 |
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|
|
|
|
| 143 |
if re.match(r".*\.pdf$", filename):
|
| 144 |
return FileType.PDF.value
|
| 145 |
|
| 146 |
+
if re.match(r".*\.(doc|ppt|yml|xml|htm|json|csv|txt|ini|xsl|wps|rtf|hlp|pages|numbers|key|md)$", filename):
|
| 147 |
return FileType.DOC.value
|
| 148 |
|
| 149 |
if re.match(r".*\.(wav|flac|ape|alac|wavpack|wv|mp3|aac|ogg|vorbis|opus|mp3)$", filename):
|