Commit 75c7829
Parent: 2a6f834

Refactor ask decorator (#4116)

### What problem does this PR solve?

Refactor ask decorator

### Type of change

- [x] Refactoring

---------

Signed-off-by: jinhai <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>
- api/db/services/dialog_service.py +120 -85
- api/db/services/llm_service.py +18 -14
api/db/services/dialog_service.py CHANGED
```diff
@@ -23,7 +23,7 @@ from copy import deepcopy
 from timeit import default_timer as timer
 import datetime
 from datetime import timedelta
-from api.db import LLMType, ParserType,StatusEnum
+from api.db import LLMType, ParserType, StatusEnum
 from api.db.db_models import Dialog, DB
 from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
```
```diff
@@ -41,14 +41,14 @@ class DialogService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_list(cls, tenant_id,
-                 page_number, items_per_page, orderby, desc, id
+                 page_number, items_per_page, orderby, desc, id, name):
         chats = cls.model.select()
         if id:
             chats = chats.where(cls.model.id == id)
         if name:
             chats = chats.where(cls.model.name == name)
         chats = chats.where(
-
+            (cls.model.tenant_id == tenant_id)
             & (cls.model.status == StatusEnum.VALID.value)
         )
         if desc:
```
```diff
@@ -137,25 +137,37 @@ def kb_prompt(kbinfos, max_tokens):
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
-
-
-
+
+    chat_start_ts = timer()
+
+    # Get llm model name and model provider name
+    llm_id, model_provider = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
+
+    # Get llm model instance by model and provide name
+    llm = LLMService.query(llm_name=llm_id) if not model_provider else LLMService.query(llm_name=llm_id, fid=model_provider)
+
     if not llm:
-
-
+        # Model name is provided by tenant, but not system built-in
+        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not model_provider else \
+            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=model_provider)
         if not llm:
             raise LookupError("LLM(%s) not found" % dialog.llm_id)
         max_tokens = 8192
     else:
         max_tokens = llm[0].max_tokens
+
+    check_llm_ts = timer()
+
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
-
-    if len(
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
+    if len(embedding_list) != 1:
         yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
         return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
 
-
-
+    embedding_model_name = embedding_list[0]
+
+    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
 
     questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
     attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
```
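Two behavioural pieces land in this hunk besides the timers: a guard that all knowledge bases share one embedding model (their vectors are not comparable otherwise), and retriever selection, where dialogs whose KBs were all parsed as knowledge graphs get the KG-aware retriever. A standalone sketch of that selection (the helper name is mine; the real `chat()` yields an `**ERROR**` answer rather than raising):

```python
# Sketch of the KB validation + retriever choice above. `kbs` are objects
# with embd_id/parser_id attributes; `settings` and `parser_kg` stand in for
# the api.settings module and ParserType.KG.
def select_retrieval_setup(kbs, settings, parser_kg):
    embedding_list = list(set(kb.embd_id for kb in kbs))
    if len(embedding_list) != 1:
        # chat() yields an "**ERROR**" answer here instead of raising
        raise ValueError("Knowledge bases use different embedding models.")
    is_knowledge_graph = all(kb.parser_id == parser_kg for kb in kbs)
    retriever = settings.kg_retrievaler if is_knowledge_graph else settings.retrievaler
    return embedding_list[0], retriever
```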
```diff
@@ -165,15 +177,21 @@ def chat(dialog, messages, stream=True, **kwargs):
         if "doc_ids" in m:
             attachments.extend(m["doc_ids"])
 
-
+    create_retriever_ts = timer()
+
+    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
     if not embd_mdl:
-        raise LookupError("Embedding model(%s) not found" %
+        raise LookupError("Embedding model(%s) not found" % embedding_model_name)
+
+    bind_embedding_ts = timer()
 
     if llm_id2llm_type(dialog.llm_id) == "image2text":
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
     else:
         chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
 
+    bind_llm_ts = timer()
+
     prompt_config = dialog.prompt_config
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     tts_mdl = None
```
```diff
@@ -200,32 +218,35 @@ def chat(dialog, messages, stream=True, **kwargs):
         questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
     else:
         questions = questions[-1:]
-
-
+
+    refine_question_ts = timer()
 
     rerank_mdl = None
     if dialog.rerank_id:
         rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
 
-
-
+    bind_reranker_ts = timer()
+    generate_keyword_ts = bind_reranker_ts
+
     if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
         kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
     else:
         if prompt_config.get("keyword", False):
             questions[-1] += keyword_extraction(chat_mdl, questions[-1])
-
+            generate_keyword_ts = timer()
 
         tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-        kbinfos =
-
-
-
-
+        kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
+                                      dialog.similarity_threshold,
+                                      dialog.vector_similarity_weight,
+                                      doc_ids=attachments,
+                                      top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+
+    retrieval_ts = timer()
+
     knowledges = kb_prompt(kbinfos, max_tokens)
     logging.debug(
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
-    retrieval_tm = timer()
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
```
```diff
@@ -249,17 +270,20 @@ def chat(dialog, messages, stream=True, **kwargs):
                 max_tokens - used_token_count)
 
     def decorate_answer(answer):
-        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt,
+        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts
+
+        finish_chat_ts = timer()
+
         refs = []
         if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
-            answer, idx =
-
-
-
-
-
-
-
+            answer, idx = retriever.insert_citations(answer,
+                                                     [ck["content_ltks"]
+                                                      for ck in kbinfos["chunks"]],
+                                                     [ck["vector"]
+                                                      for ck in kbinfos["chunks"]],
+                                                     embd_mdl,
+                                                     tkweight=1 - dialog.vector_similarity_weight,
+                                                     vtweight=dialog.vector_similarity_weight)
             idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
             recall_docs = [
                 d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
```
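Note on the citation pass: `tkweight` and `vtweight` are kept complementary, with the token-level weight derived as `1 - dialog.vector_similarity_weight`. A toy sketch of the convex blend those two knobs imply (the actual scoring lives inside `insert_citations`; the names below are assumptions):

```python
# Hypothetical illustration of the tkweight/vtweight blend: a convex
# combination of token-overlap similarity and embedding similarity.
def blended_similarity(token_sim, vector_sim, vtweight):
    tkweight = 1 - vtweight  # mirrors tkweight=1 - dialog.vector_similarity_weight
    return tkweight * token_sim + vtweight * vector_sim

# With vector_similarity_weight = 0.3, lexical overlap still dominates:
assert abs(blended_similarity(0.5, 1.0, 0.3) - 0.65) < 1e-9
```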
```diff
@@ -274,10 +298,20 @@ def chat(dialog, messages, stream=True, **kwargs):
 
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
-
-
-
-
+        finish_chat_ts = timer()
+
+        total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
+        check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
+        create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
+        bind_embedding_time_cost = (bind_embedding_ts - create_retriever_ts) * 1000
+        bind_llm_time_cost = (bind_llm_ts - bind_embedding_ts) * 1000
+        refine_question_time_cost = (refine_question_ts - bind_llm_ts) * 1000
+        bind_reranker_time_cost = (bind_reranker_ts - refine_question_ts) * 1000
+        generate_keyword_time_cost = (generate_keyword_ts - bind_reranker_ts) * 1000
+        retrieval_time_cost = (retrieval_ts - generate_keyword_ts) * 1000
+        generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000
+
+        prompt = f"{prompt} ### Elapsed\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
         return {"answer": answer, "reference": refs, "prompt": prompt}
 
     if stream:
```
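The new `*_ts` variables implement a simple stage clock: one `timer()` mark per stage, pairwise differences reported in milliseconds and appended to the prompt for debugging. A self-contained sketch of the same pattern (the helper name is mine):

```python
from timeit import default_timer as timer

# Generic version of the stage timing chat() now does by hand: record a
# mark after every stage, then report neighbouring differences in ms.
def profile_stages(stages):
    """stages: iterable of (name, zero-arg callable) -> {name: elapsed_ms}"""
    marks = [("start", timer())]
    for name, fn in stages:
        fn()
        marks.append((name, timer()))
    return {name: (ts - prev) * 1000.0
            for (_, prev), (name, ts) in zip(marks, marks[1:])}

# Example with stand-in stages:
costs = profile_stages([("retrieval", lambda: None), ("generate", lambda: None)])
print(costs)  # e.g. {'retrieval': 0.0..., 'generate': 0.0...}
```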
```diff
@@ -304,15 +338,15 @@ def chat(dialog, messages, stream=True, **kwargs):
 
 
 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
-    sys_prompt = "
-
-
-
+    sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question."
+    user_prompt = """
+Table name: {};
+Table of database fields are as follows:
 {}
 
-
+Question are as follows:
 {}
-
+Please write the SQL, only SQL, without any other explanations or text.
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
```
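Splitting the old single prompt into `sys_prompt` (the DBA role) and `user_prompt` (schema plus question) is the visible shape change here. For reference, with toy values the formatted user prompt renders like this (the index name and field map below are made-up stand-ins; the template text is from the diff):

```python
# Toy rendering of use_sql()'s user_prompt template with hypothetical values.
user_prompt = """
Table name: {};
Table of database fields are as follows:
{}

Question are as follows:
{}
Please write the SQL, only SQL, without any other explanations or text.
""".format(
    "ragflow_tenant_x_idx",
    "\n".join(f"{k}: {v}" for k, v in
              {"title_tks": "Title", "amount_flt": "Amount"}.items()),
    "What is the total amount?")
print(user_prompt)
```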
```diff
@@ -321,10 +355,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     tried_times = 0
 
     def get_table():
-        nonlocal sys_prompt,
-        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content":
+        nonlocal sys_prompt, user_prompt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
             "temperature": 0.06})
-        logging.debug(f"{question} ==> {
+        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
         sql = re.sub(r".*select ", "select ", sql.lower())
         sql = re.sub(r" +", " ", sql)
```
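The three `re.sub` context lines above are the guard against chatty model output; they tolerate answers like "Here is your SQL: SELECT ...". A standalone check:

```python
import re

# The same normalization get_table() applies to raw model output:
# flatten newlines, cut everything before the last "select ", squeeze spaces.
def normalize_sql(sql):
    sql = re.sub(r"[\r\n]+", " ", sql.lower())
    sql = re.sub(r".*select ", "select ", sql.lower())
    sql = re.sub(r" +", " ", sql)
    return sql

assert normalize_sql("Here is your SQL:\nSELECT  *  FROM t") == "select * from t"
```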
```diff
@@ -352,21 +386,23 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     if tbl is None:
         return None
     if tbl.get("error") and tried_times <= 2:
-
-
-
+        user_prompt = """
+Table name: {};
+Table of database fields are as follows:
 {}
-
-
+
+Question are as follows:
 {}
+Please write the SQL, only SQL, without any other explanations or text.
+
 
-
+The SQL error you provided last time is as follows:
 {}
 
-
+Error issued by database as follows:
 {}
 
-
+Please correct the error and write SQL again, only SQL, without any other explanations or text.
 """.format(
         index_name(tenant_id),
         "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
```
```diff
@@ -381,21 +417,21 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
     docid_idx = set([ii for ii, c in enumerate(
         tbl["columns"]) if c["name"] == "doc_id"])
-
+    doc_name_idx = set([ii for ii, c in enumerate(
         tbl["columns"]) if c["name"] == "docnm_kwd"])
-
-        len(tbl["columns"])) if ii not in (docid_idx |
+    column_idx = [ii for ii in range(
+        len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)]
 
-    # compose
-
-
-
+    # compose Markdown table
+    columns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
+                                                                          tbl["columns"][i]["name"])) for i in
+                              column_idx]) + ("|Source|" if docid_idx and docid_idx else "|")
 
-    line = "|" + "|".join(["------" for _ in range(len(
+    line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + \
           ("|------|" if docid_idx and docid_idx else "")
 
     rows = ["|" +
-            "|".join([rmSpace(str(r[i])) for i in
+            "|".join([rmSpace(str(r[i])) for i in column_idx]).replace("None", " ") +
             "|" for r in tbl["rows"]]
     rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
     if quota:
```
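The `columns`/`line`/`rows` triple builds a Markdown table by hand, appending a Source column only when the SQL result carries `doc_id`/`docnm_kwd`. A minimal standalone version of the same assembly (the helper is mine, and it mirrors the quirk that the separator row only gets a trailing pipe via the Source column):

```python
# Minimal re-creation of the Markdown assembly in use_sql().
def to_markdown(headers, rows, with_source=False):
    head = "|" + "|".join(headers) + ("|Source|" if with_source else "|")
    line = "|" + "|".join("------" for _ in headers) + \
           ("|------|" if with_source else "")
    body = ["|" + "|".join(str(c) for c in row) + "|" for row in rows]
    return "\n".join([head, line] + body)

print(to_markdown(["Title", "Amount"], [("Q3 report", 42)]))
# |Title|Amount|
# |------|------
# |Q3 report|42|
```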
```diff
@@ -404,24 +440,24 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
     rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
 
-    if not docid_idx or not
+    if not docid_idx or not doc_name_idx:
         logging.warning("SQL missing field: " + sql)
         return {
-            "answer": "\n".join([
+            "answer": "\n".join([columns, line, rows]),
             "reference": {"chunks": [], "doc_aggs": []},
             "prompt": sys_prompt
         }
 
     docid_idx = list(docid_idx)[0]
-
+    doc_name_idx = list(doc_name_idx)[0]
     doc_aggs = {}
     for r in tbl["rows"]:
         if r[docid_idx] not in doc_aggs:
-            doc_aggs[r[docid_idx]] = {"doc_name": r[
+            doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0}
         doc_aggs[r[docid_idx]]["count"] += 1
     return {
-        "answer": "\n".join([
-        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[
+        "answer": "\n".join([columns, line, rows]),
+        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]],
                       "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                    doc_aggs.items()]},
         "prompt": sys_prompt
```
```diff
@@ -492,7 +528,7 @@ Requirements:
     kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
     if isinstance(kwd, tuple):
         kwd = kwd[0]
-    if kwd.find("**ERROR**") >=0:
+    if kwd.find("**ERROR**") >= 0:
         return ""
     return kwd
 
```
```diff
@@ -605,16 +641,16 @@ def tts(tts_mdl, text):
 
 def ask(question, kb_ids, tenant_id):
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
-
+    embedding_list = list(set([kb.embd_id for kb in kbs]))
 
-
-
+    is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs])
+    retriever = settings.retrievaler if not is_knowledge_graph else settings.kg_retrievaler
 
-    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING,
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    kbinfos =
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
 Role: You're a smart assistant. Your name is Miss R.
```
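`ask()` mirrors `chat()` but pins the retrieval knobs that `chat()` reads from the dialog. Lining the positional call above up against `chat()`'s version, the fixed values appear to be:

```python
# Assumed meaning of ask()'s positional retrieval arguments, inferred by
# aligning them with the named dialog fields used in chat():
ASK_RETRIEVAL_DEFAULTS = {
    "page": 1,
    "top_n": 12,                      # chunks considered for the prompt
    "similarity_threshold": 0.1,
    "vector_similarity_weight": 0.3,  # matches tkweight=0.7 / vtweight=0.3 below
}
```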
```diff
@@ -636,14 +672,14 @@ def ask(question, kb_ids, tenant_id):
 
     def decorate_answer(answer):
         nonlocal knowledges, kbinfos, prompt
-        answer, idx =
-
-
-
-
-
-
-
+        answer, idx = retriever.insert_citations(answer,
+                                                 [ck["content_ltks"]
+                                                  for ck in kbinfos["chunks"]],
+                                                 [ck["vector"]
+                                                  for ck in kbinfos["chunks"]],
+                                                 embd_mdl,
+                                                 tkweight=0.7,
+                                                 vtweight=0.3)
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [
             d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
```
```diff
@@ -664,4 +700,3 @@ def ask(question, kb_ids, tenant_id):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)
-
```
api/db/services/llm_service.py CHANGED
```diff
@@ -72,10 +72,12 @@ class TenantLLMService(CommonService):
             return model_name, None
         if len(arr) > 2:
             return "@".join(arr[0:-1]), arr[-1]
+
+        # model name must be xxx@yyy
         try:
-
-
-            if arr[-1] not in
+            model_factories = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            model_providers = set([f["name"] for f in model_factories])
+            if arr[-1] not in model_providers:
                 return model_name, None
             return arr[0], arr[-1]
         except Exception as e:
```
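The provider suffix is now validated against `conf/llm_factories.json` instead of being accepted blindly. A standalone sketch of the split rule (the `len(arr) < 2` guard is reconstructed from the surrounding context lines and is an assumption):

```python
# Sketch of split_model_name_and_factory: split "model@factory" on the last
# '@', but for a single '@' only trust the suffix if it is a known provider.
def split_model_name_and_factory(model_name, known_providers):
    arr = model_name.split("@")
    if len(arr) < 2:                      # reconstructed guard
        return model_name, None
    if len(arr) > 2:                      # model names may contain '@' themselves
        return "@".join(arr[0:-1]), arr[-1]
    if arr[-1] not in known_providers:
        return model_name, None
    return arr[0], arr[-1]

providers = {"Tongyi-Qianwen", "OpenAI"}
assert split_model_name_and_factory("qwen-max@Tongyi-Qianwen", providers) == ("qwen-max", "Tongyi-Qianwen")
assert split_model_name_and_factory("[email protected]", providers) == ("[email protected]", None)
```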
```diff
@@ -113,11 +115,11 @@ class TenantLLMService(CommonService):
         if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
             llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
             if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
-                model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": mdlnm, "api_base": ""}
+                model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
         if not model_config:
             if mdlnm == "flag-embedding":
                 model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "",
-
+                                "llm_name": llm_name, "api_base": ""}
             else:
                 if not mdlnm:
                     raise LookupError(f"Type of {llm_type} model is not set.")
```
```diff
@@ -200,8 +202,8 @@ class TenantLLMService(CommonService):
                 return num
             else:
                 tenant_llm = tenant_llms[0]
-                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens)\
-                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
+                num = cls.model.update(used_tokens=tenant_llm.used_tokens + used_tokens) \
+                    .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name) \
                     .execute()
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
```
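Style nit on the hunk above: the counter is computed in Python from a previously read row (`tenant_llm.used_tokens + used_tokens`), so two concurrent calls can lose an increment. A sketch of the atomic alternative with a peewee field expression (the model and fields are stand-ins for `cls.model`):

```python
from peewee import SqliteDatabase, Model, CharField, IntegerField

db = SqliteDatabase(":memory:")

class TenantLLM(Model):          # stand-in for cls.model
    tenant_id = CharField()
    llm_name = CharField()
    used_tokens = IntegerField(default=0)

    class Meta:
        database = db

db.create_tables([TenantLLM])
TenantLLM.create(tenant_id="t1", llm_name="qwen-max")

# In-database increment: the addition happens in SQL, so there is no race
# between the read and the write.
(TenantLLM
 .update(used_tokens=TenantLLM.used_tokens + 128)
 .where((TenantLLM.tenant_id == "t1") & (TenantLLM.llm_name == "qwen-max"))
 .execute())
```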
```diff
@@ -231,7 +233,7 @@ class LLMBundle(object):
         for lm in LLMService.query(llm_name=llm_name):
             self.max_length = lm.max_tokens
             break
-
+
     def encode(self, texts: list):
         embeddings, used_tokens = self.mdl.encode(texts)
         if not TenantLLMService.increase_usage(
```
```diff
@@ -274,11 +276,11 @@ class LLMBundle(object):
 
     def tts(self, text):
         for chunk in self.mdl.tts(text):
-            if isinstance(chunk,int):
+            if isinstance(chunk, int):
                 if not TenantLLMService.increase_usage(
-
-
-
+                        self.tenant_id, self.llm_type, chunk, self.llm_name):
+                    logging.error(
+                        "LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
                 return
             yield chunk
 
```
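The `isinstance(chunk, int)` test encodes a small protocol: the underlying model generator yields payload chunks and finishes with a single int of consumed tokens, which the wrapper peels off for usage accounting. Sketched with a fake generator (names are mine):

```python
# Fake tts generator demonstrating the sentinel protocol the wrapper relies
# on: bytes chunks first, then one int with the token count.
def fake_tts(text):
    for i in range(0, len(text), 4):
        yield text[i:i + 4].encode()
    yield len(text)  # usage sentinel

def consume(stream):
    used = 0
    for chunk in stream:
        if isinstance(chunk, int):  # not audio: record usage and stop
            used = chunk
            break
        # ...forward the audio chunk to the client...
    return used

assert consume(fake_tts("hello world")) == 11
```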
```diff
@@ -287,7 +289,8 @@ class LLMBundle(object):
         if isinstance(txt, int) and not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, used_tokens, self.llm_name):
             logging.error(
-                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
+                "LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name,
+                                                                                                           used_tokens))
         return txt
 
     def chat_streamly(self, system, history, gen_conf):
```
```diff
@@ -296,6 +299,7 @@ class LLMBundle(object):
         if not TenantLLMService.increase_usage(
                 self.tenant_id, self.llm_type, txt, self.llm_name):
             logging.error(
-                "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
+                "LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name,
+                                                                                                                txt))
             return
         yield txt
```