KevinHuSh
committed
Commit · 028fe40
Parent(s): df17cda
add stream chat (#811)
### What problem does this PR solve?
#709
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/apps/api_app.py +45 -44
- api/apps/conversation_app.py +43 -7
- api/apps/system_app.py +67 -0
- api/db/services/dialog_service.py +51 -32
- api/db/services/document_service.py +2 -2
- api/db/services/llm_service.py +11 -1
- api/utils/api_utils.py +0 -4
- rag/llm/chat_model.py +107 -1
- rag/svr/task_executor.py +12 -13
- rag/utils/es_conn.py +3 -0
- rag/utils/minio_conn.py +10 -0
- rag/utils/redis_conn.py +4 -0
api/apps/api_app.py
CHANGED
```diff
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 import re
 from datetime import datetime, timedelta
-from flask import request
+from flask import request, Response
 from flask_login import login_required, current_user
 
 from api.db import FileType, ParserType
@@ -31,11 +32,11 @@ from api.settings import RetCode
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
 from itsdangerous import URLSafeTimedSerializer
-
+
 from api.utils.file_utils import filename_type, thumbnail
 from rag.utils.minio_conn import MINIO
-
-
+
+
 def generate_confirmation_token(tenent_id):
     serializer = URLSafeTimedSerializer(tenent_id)
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
         e, conv = API4ConversationService.get_by_id(req["conversation_id"])
         if not e:
             return get_data_error_result(retmsg="Conversation not found!")
+        if "quote" not in req: req["quote"] = False
 
         msg = []
         for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
+
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        API4ConversationService.append_message(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                API4ConversationService.append_message(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            API4ConversationService.append_message(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
+
     except Exception as e:
         return server_error_response(e)
 
@@ -229,7 +263,6 @@ def upload():
         return get_json_result(
             data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
 
-
     file = request.files['file']
     if file.filename == '':
         return get_json_result(
@@ -253,7 +286,6 @@ def upload():
             location += "_"
         blob = request.files['file'].read()
         MINIO.put(kb_id, location, blob)
-
         doc = {
             "id": get_uuid(),
             "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
             "size": len(blob),
             "thumbnail": thumbnail(filename, blob)
         }
-
-        form_data = request.form
-        if "parser_id" in form_data.keys():
-            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
-                doc["parser_id"] = request.form.get("parser_id").strip()
         if doc["type"] == FileType.VISUAL:
             doc["parser_id"] = ParserType.PICTURE.value
         if re.search(r"\.(ppt|pptx|pages)$", filename):
             doc["parser_id"] = ParserType.PRESENTATION.value
-
-        doc_result = DocumentService.insert(doc)
-
+        doc = DocumentService.insert(doc)
+        return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)
-
-    if "run" in form_data.keys():
-        if request.form.get("run").strip() == "1":
-            try:
-                info = {"run": 1, "progress": 0}
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
-                DocumentService.update_by_id(doc["id"], info)
-                # if str(req["run"]) == TaskStatus.CANCEL.value:
-                tenant_id = DocumentService.get_tenant_id(doc["id"])
-                if not tenant_id:
-                    return get_data_error_result(retmsg="Tenant not found!")
-
-                #e, doc = DocumentService.get_by_id(doc["id"])
-                TaskService.filter_delete([Task.doc_id == doc["id"]])
-                e, doc = DocumentService.get_by_id(doc["id"])
-                doc = doc.to_dict()
-                doc["tenant_id"] = tenant_id
-                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
-            except Exception as e:
-                return server_error_response(e)
-
-    return get_json_result(data=doc_result.to_json())
```
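The `stream()` generator above emits server-sent events: each frame is a `data:` line carrying a JSON envelope with the cumulative answer, and a final frame whose `data` is `true` marks the end of the stream. For illustration, a minimal client for the token-based API might look like the following sketch; it is not part of the commit, and the host, route prefix, token, and conversation id are placeholders:

```python
import json
import requests  # assumed client-side dependency, not part of this commit

resp = requests.post(
    "http://localhost:9380/v1/api/completion",        # placeholder host/prefix
    headers={"Authorization": "Bearer <API_TOKEN>"},  # placeholder API token
    json={"conversation_id": "<CONVERSATION_ID>",
          "messages": [{"role": "user", "content": "Hello!"}],
          "stream": True},
    stream=True,
)
answer = ""
for line in resp.iter_lines():
    if not line.startswith(b"data:"):
        continue
    frame = json.loads(line[len(b"data:"):])
    if frame["data"] is True:          # terminal sentinel frame
        break
    answer = frame["data"]["answer"]   # each frame repeats the full answer so far
print(answer)
```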
api/apps/conversation_app.py
CHANGED
```diff
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from flask import request
+from flask import request, Response, jsonify
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
+import json
 
 
 @manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():
 
 @manager.route('/completion', methods=['POST'])
 @login_required
-@validate_request("conversation_id", "messages")
+#@validate_request("conversation_id", "messages")
 def completion():
     req = request.json
+    #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
+    #    {"role": "user", "content": "上海有吗?"}
+    #]}
     msg = []
     for m in req["messages"]:
         if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
+
         if not conv.reference:
             conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        ConversationService.update_by_id(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else: conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
+                ConversationService.update_by_id(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+
+        else:
+            ans = chat(dia, msg, False, **req)
+            fillin_conv(ans)
+            ConversationService.update_by_id(conv.id, conv.to_dict())
+            return get_json_result(data=ans)
     except Exception as e:
         return server_error_response(e)
 
```
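One subtlety both completion endpoints inherit: because `chat()` in `dialog_service.py` now contains `yield`, calling it always produces a generator object, even with `stream=False`, and a `return value` inside a generator surfaces as `StopIteration.value` rather than as an ordinary call result. A self-contained illustration of the semantics the non-streaming branch has to contend with (not code from the commit):

```python
def gen(stream=True):
    if stream:
        yield {"answer": "partial"}
    return {"answer": "final"}   # becomes StopIteration.value, not a return value

g = gen(stream=False)            # g is a generator object, not a dict
try:
    next(g)
except StopIteration as e:
    print(e.value)               # {'answer': 'final'} must be unwrapped like this
```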
api/apps/system_app.py
ADDED
```diff
@@ -0,0 +1,67 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License
+#
+from flask_login import login_required
+
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.utils.api_utils import get_json_result
+from api.versions import get_rag_version
+from rag.settings import SVR_QUEUE_NAME
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.minio_conn import MINIO
+from timeit import default_timer as timer
+
+from rag.utils.redis_conn import REDIS_CONN
+
+
+@manager.route('/version', methods=['GET'])
+@login_required
+def version():
+    return get_json_result(data=get_rag_version())
+
+
+@manager.route('/status', methods=['GET'])
+@login_required
+def status():
+    res = {}
+    st = timer()
+    try:
+        res["es"] = ELASTICSEARCH.health()
+        res["es"]["elapsed"] = "{:.1f}".format((timer() - st)*1000.)
+    except Exception as e:
+        res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        MINIO.health()
+        res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
+    except Exception as e:
+        res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        KnowledgebaseService.get_by_id("x")
+        res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.)}
+    except Exception as e:
+        res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
+        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st)*1000.), "pending": qinfo["pending"]}
+    except Exception as e:
+        res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st)*1000.), "error": str(e)}
+
+    return get_json_result(data=res)
```
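For illustration, the two new endpoints could be probed like this; a sketch only, where the host, URL prefix, and session cookie are placeholders, and `elapsed` values are strings of milliseconds:

```python
import requests  # assumed client-side dependency, not part of this commit

s = requests.Session()
s.cookies.set("session", "<FLASK_SESSION>")  # both endpoints are @login_required

print(s.get("http://localhost:9380/v1/system/version").json()["data"])

status = s.get("http://localhost:9380/v1/system/status").json()["data"]
for svc in ("es", "minio", "mysql", "redis"):
    print(svc, status[svc].get("status"), status[svc].get("elapsed"), "ms")
```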
api/db/services/dialog_service.py
CHANGED
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+from copy import deepcopy
 
 from api.db import LLMType
 from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def message_fit_in(msg, max_length=4000):
     return max_length, msg
 
 
-def chat(dialog, messages, **kwargs):
+def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, stream=True, **kwargs):
     else: max_tokens = llm[0].max_tokens
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
-    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+    if len(embd_nms) != 1:
+        if stream:
+            yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
+        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
     questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
-        if ans: return ans
+        if ans:
+            yield ans
+            return
 
     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
     if not knowledges and prompt_config.get("empty_response"):
-        return {
-            "answer": prompt_config["empty_response"], "reference": kbinfos}
+        if stream:
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, stream=True, **kwargs):
     gen_conf["max_tokens"] = min(
         gen_conf["max_tokens"],
         max_tokens - used_token_count)
-    answer = chat_mdl.chat(
-        prompt_config["system"].format(
-            **kwargs), msg, gen_conf)
-    chat_logger.info("User: {}|Assistant: {}".format(
-        msg[-1]["content"], answer))
-    ... (the remaining deleted lines, the inline citation-insertion and
-         "invalid key" post-processing now moved into decorate_answer()
-         below, were not recoverable from this render)
+
+    def decorate_answer(answer):
+        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
+            answer, idx = retrievaler.insert_citations(answer,
+                                                       [ck["content_ltks"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       [ck["vector"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       embd_mdl,
+                                                       tkweight=1 - dialog.vector_similarity_weight,
+                                                       vtweight=dialog.vector_similarity_weight)
+            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+            recall_docs = [
+                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+            if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+            kbinfos["doc_aggs"] = recall_docs
+
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    if stream:
+        answer = ""
+        for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
+            answer = ans
+            yield {"answer": answer, "reference": {}}
+        yield decorate_answer(answer)
+    else:
+        answer = chat_mdl.chat(
+            prompt_config["system"].format(
+                **kwargs), msg, gen_conf)
+        chat_logger.info("User: {}|Assistant: {}".format(
+            msg[-1]["content"], answer))
+        return decorate_answer(answer)
 
 
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
```
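In streaming mode `chat()` yields one dict per model chunk, each carrying the whole answer accumulated so far with an empty `reference`, and only the final `decorate_answer()` frame carries citations and the vector-stripped reference. A consumer can therefore just keep the last frame; a hedged sketch, where `dialog` and `messages` are assumed to exist and the two callbacks are hypothetical:

```python
last = None
for frame in chat(dialog, messages, stream=True):
    redraw_ui(frame["answer"])   # hypothetical UI hook: full answer so far
    last = frame
persist(last["reference"])       # hypothetical store: only the last frame
                                 # has the citation-decorated reference
```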
api/db/services/document_service.py
CHANGED
```diff
@@ -43,7 +43,7 @@ class DocumentService(CommonService):
             docs = cls.model.select().where(
                 (cls.model.kb_id == kb_id),
                 (fn.LOWER(cls.model.name).contains(keywords.lower()))
-                )
+            )
         else:
             docs = cls.model.select().where(cls.model.kb_id == kb_id)
         count = docs.count()
@@ -75,7 +75,7 @@ class DocumentService(CommonService):
     def delete(cls, doc):
         e, kb = KnowledgebaseService.get_by_id(doc.kb_id)
         if not KnowledgebaseService.update_by_id(
-                kb.id, {"doc_num": kb.doc_num - 1}):
+                kb.id, {"doc_num": max(0, kb.doc_num - 1)}):
             raise RuntimeError("Database error (Knowledgebase)!")
         return cls.delete_by_id(doc.id)
 
```
api/db/services/llm_service.py
CHANGED
```diff
@@ -172,8 +172,18 @@ class LLMBundle(object):
 
     def chat(self, system, history, gen_conf):
         txt, used_tokens = self.mdl.chat(system, history, gen_conf)
-        if TenantLLMService.increase_usage(
+        if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             database_logger.error(
                 "Can't update token usage for {}/CHAT".format(self.tenant_id))
         return txt
+
+    def chat_streamly(self, system, history, gen_conf):
+        for txt in self.mdl.chat_streamly(system, history, gen_conf):
+            if isinstance(txt, int):
+                if not TenantLLMService.increase_usage(
+                        self.tenant_id, self.llm_type, txt, self.llm_name):
+                    database_logger.error(
+                        "Can't update token usage for {}/CHAT".format(self.tenant_id))
+                return
+            yield txt
```
api/utils/api_utils.py
CHANGED
```diff
@@ -25,7 +25,6 @@ from flask import (
 from werkzeug.http import HTTP_STATUS_CODES
 
 from api.utils import json_dumps
-from api.versions import get_rag_version
 from api.settings import RetCode
 from api.settings import (
     REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC,
@@ -84,9 +83,6 @@ def request(**kwargs):
     return sess.send(prepped, stream=stream, timeout=timeout)
 
 
-rag_version = get_rag_version() or ''
-
-
 def get_exponential_backoff_interval(retries, full_jitter=False):
     """Calculate the exponential backoff wait time."""
     # Will be zero if factor equals 0
```
rag/llm/chat_model.py
CHANGED
```diff
@@ -20,7 +20,6 @@ from openai import OpenAI
 import openai
 from ollama import Client
 from rag.nlp import is_english
-from rag.utils import num_tokens_from_string
 
 
 class Base(ABC):
@@ -44,6 +43,31 @@ class Base(ABC):
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ class QWenChat(Base):
 
         return "**ERROR**: " + response.message, tk_count
 
+    def chat_streamly(self, system, history, gen_conf):
+        from http import HTTPStatus
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        try:
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ class ZhipuChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            tk_count = 0
+            for resp in response:
+                if not resp.choices[0].delta.content: continue
+                delta = resp.choices[0].delta.content
+                ans += delta
+                tk_count = resp.usage.total_tokens if response.usage else 0
+                if resp.output.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ class OllamaChat(Base):
         except Exception as e:
             return "**ERROR**: " + str(e), 0
 
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_k"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options
+            )
+            for resp in response:
+                if resp["done"]:
+                    return resp["prompt_eval_count"] + resp["eval_count"]
+                ans = resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
```
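Note that `Base.chat_streamly()` approximates usage by counting one token per streamed delta (`total_tokens += 1`) rather than querying the API. Because it targets the OpenAI-compatible SDK, a further OpenAI-compatible provider would only need a constructor; a hedged sketch, where the class name and base URL are hypothetical:

```python
from openai import OpenAI

class MyProviderChat(Base):
    """Hypothetical adapter: inherits chat() and chat_streamly() from Base."""
    def __init__(self, key, model_name, base_url="https://llm.example.com/v1"):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
```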
rag/svr/task_executor.py
CHANGED
```diff
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
 
     if to_page > 0:
         if msg:
-            msg = f"Page({from_page+1}~{to_page+1}): " + msg
+            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     d = {"progress_msg": msg}
     if prog is not None:
         d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
 def build(row):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
-                     (int(DOC_MAXIMUM_SIZE/1024/1024)))
+                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []
 
     callback = partial(
@@ -138,12 +138,12 @@ def build(row):
         bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
         binary = get_minio_binary(bucket, name)
         cron_logger.info(
-            "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
         cron_logger.info(
-            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except TimeoutError as e:
         callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
         cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
         d.update(ck)
         md5 = hashlib.md5()
         md5.update((ck["content_with_weight"] +
-                   str(d["doc_id"])).encode("utf-8"))
+                    str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():
 
         st = timer()
         cks = build(r)
-        cron_logger.info("Build chunks({}): {}".format(r["name"], timer()-st))
+        cron_logger.info("Build chunks({}): {}".format(r["name"], timer() - st))
         if cks is None:
             continue
         if not cks:
@@ -271,7 +271,7 @@ def main():
         ## set_progress(r["did"], -1, "ERROR: ")
         callback(
             msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
+                len(cks))
         st = timer()
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))
 
-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
+        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
         for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
+            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")
 
-        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
             r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
         cron_logger.info(
             "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                r["id"], tk_count, len(cks), timer()-st))
-
+                r["id"], tk_count, len(cks), timer() - st))
 
 
 if __name__ == "__main__":
```
rag/utils/es_conn.py
CHANGED
```diff
@@ -43,6 +43,9 @@ class ESConnection:
         v = v["number"].split(".")[0]
         return int(v) >= 7
 
+    def health(self):
+        return dict(self.es.cluster.health())
+
     def upsert(self, df, idxnm=""):
         res = []
         for d in df:
```
rag/utils/minio_conn.py
CHANGED
```diff
@@ -34,6 +34,16 @@ class RAGFlowMinio(object):
         del self.conn
         self.conn = None
 
+    def health(self):
+        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
+        if not self.conn.bucket_exists(bucket):
+            self.conn.make_bucket(bucket)
+        r = self.conn.put_object(bucket, fnm,
+                                 BytesIO(binary),
+                                 len(binary)
+                                 )
+        return r
+
     def put(self, bucket, fnm, binary):
         for _ in range(3):
             try:
```
rag/utils/redis_conn.py
CHANGED
```diff
@@ -44,6 +44,10 @@ class RedisDB:
             logging.warning("Redis can't be connected.")
         return self.REDIS
 
+    def health(self, queue_name):
+        self.REDIS.ping()
+        return self.REDIS.xinfo_groups(queue_name)[0]
+
     def is_alive(self):
         return self.REDIS is not None
 
```
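`RedisDB.health()` doubles as a cheap backlog probe: `xinfo_groups` reports, per consumer group of the stream, how many entries were delivered but not yet acknowledged. A hedged sketch of using it outside `system_app.py`:

```python
from rag.settings import SVR_QUEUE_NAME
from rag.utils.redis_conn import REDIS_CONN

# "pending" counts messages handed to task executors but not yet acked,
# i.e. the parsing backlog of the SVR_QUEUE_NAME stream.
info = REDIS_CONN.health(SVR_QUEUE_NAME)
print("pending tasks:", info["pending"])
```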