liuhua
liuhua
committed on
Commit
·
678763e
1
Parent(s):
5000eb5
Fix: rerank_model and pdf_parser bugs | Update: session API (#2601)
Browse files
### What problem does this PR solve?
Fix: rerank_model and pdf_parser bugs | Update: session API
#2575
#2559
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring
---------
Co-authored-by: liuhua <[email protected]>
api/apps/sdk/session.py
CHANGED
|
@@ -87,9 +87,9 @@ def completion(tenant_id):
|
|
| 87 |
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
| 88 |
# {"role": "user", "content": "上海有吗?"}
|
| 89 |
# ]}
|
| 90 |
-
if "
|
| 91 |
-
return get_data_error_result(retmsg="
|
| 92 |
-
conv = ConversationService.query(id=req["
|
| 93 |
if not conv:
|
| 94 |
return get_data_error_result(retmsg="Session does not exist")
|
| 95 |
conv = conv[0]
|
|
@@ -108,7 +108,7 @@ def completion(tenant_id):
|
|
| 108 |
msg.append(m)
|
| 109 |
message_id = msg[-1].get("id")
|
| 110 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 111 |
-
del req["
|
| 112 |
|
| 113 |
if not conv.reference:
|
| 114 |
conv.reference = []
|
|
@@ -168,6 +168,9 @@ def get(tenant_id):
|
|
| 168 |
return get_data_error_result(retmsg="Session does not exist")
|
| 169 |
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
| 170 |
return get_data_error_result(retmsg="You do not own the session")
|
|
|
|
|
|
|
|
|
|
| 171 |
conv = conv[0].to_dict()
|
| 172 |
conv['messages'] = conv.pop("message")
|
| 173 |
conv["assistant_id"] = conv.pop("dialog_id")
|
|
@@ -207,7 +210,7 @@ def list(tenant_id):
|
|
| 207 |
assistant_id = request.args["assistant_id"]
|
| 208 |
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
|
| 209 |
return get_json_result(
|
| 210 |
-
data=False, retmsg=f'
|
| 211 |
retcode=RetCode.OPERATING_ERROR)
|
| 212 |
convs = ConversationService.query(
|
| 213 |
dialog_id=assistant_id,
|
|
|
|
| 87 |
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
|
| 88 |
# {"role": "user", "content": "上海有吗?"}
|
| 89 |
# ]}
|
| 90 |
+
if "session_id" not in req:
|
| 91 |
+
return get_data_error_result(retmsg="session_id is required")
|
| 92 |
+
conv = ConversationService.query(id=req["session_id"])
|
| 93 |
if not conv:
|
| 94 |
return get_data_error_result(retmsg="Session does not exist")
|
| 95 |
conv = conv[0]
|
|
|
|
| 108 |
msg.append(m)
|
| 109 |
message_id = msg[-1].get("id")
|
| 110 |
e, dia = DialogService.get_by_id(conv.dialog_id)
|
| 111 |
+
del req["session_id"]
|
| 112 |
|
| 113 |
if not conv.reference:
|
| 114 |
conv.reference = []
|
|
|
|
| 168 |
return get_data_error_result(retmsg="Session does not exist")
|
| 169 |
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
| 170 |
return get_data_error_result(retmsg="You do not own the session")
|
| 171 |
+
if "assistant_id" in req:
|
| 172 |
+
if req["assistant_id"] != conv[0].dialog_id:
|
| 173 |
+
return get_data_error_result(retmsg="The session doesn't belong to the assistant")
|
| 174 |
conv = conv[0].to_dict()
|
| 175 |
conv['messages'] = conv.pop("message")
|
| 176 |
conv["assistant_id"] = conv.pop("dialog_id")
|
|
|
|
| 210 |
assistant_id = request.args["assistant_id"]
|
| 211 |
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
|
| 212 |
return get_json_result(
|
| 213 |
+
data=False, retmsg=f"You don't own the assistant.",
|
| 214 |
retcode=RetCode.OPERATING_ERROR)
|
| 215 |
convs = ConversationService.query(
|
| 216 |
dialog_id=assistant_id,
|
deepdoc/parser/pdf_parser.py
CHANGED
|
@@ -488,7 +488,7 @@ class RAGFlowPdfParser:
|
|
| 488 |
i += 1
|
| 489 |
continue
|
| 490 |
|
| 491 |
-
if not down["text"].strip():
|
| 492 |
i += 1
|
| 493 |
continue
|
| 494 |
|
|
|
|
| 488 |
i += 1
|
| 489 |
continue
|
| 490 |
|
| 491 |
+
if not down["text"].strip() or not up["text"].strip():
|
| 492 |
i += 1
|
| 493 |
continue
|
| 494 |
|
rag/llm/rerank_model.py
CHANGED
|
@@ -26,9 +26,11 @@ from api.utils.file_utils import get_home_cache_dir
|
|
| 26 |
from rag.utils import num_tokens_from_string, truncate
|
| 27 |
import json
|
| 28 |
|
|
|
|
| 29 |
def sigmoid(x):
|
| 30 |
return 1 / (1 + np.exp(-x))
|
| 31 |
|
|
|
|
| 32 |
class Base(ABC):
|
| 33 |
def __init__(self, key, model_name):
|
| 34 |
pass
|
|
@@ -59,16 +61,19 @@ class DefaultRerank(Base):
|
|
| 59 |
with DefaultRerank._model_lock:
|
| 60 |
if not DefaultRerank._model:
|
| 61 |
try:
|
| 62 |
-
DefaultRerank._model = FlagReranker(
|
|
|
|
|
|
|
| 63 |
except Exception as e:
|
| 64 |
-
model_dir = snapshot_download(repo_id=
|
| 65 |
-
local_dir=os.path.join(get_home_cache_dir(),
|
|
|
|
| 66 |
local_dir_use_symlinks=False)
|
| 67 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
| 68 |
self._model = DefaultRerank._model
|
| 69 |
|
| 70 |
def similarity(self, query: str, texts: list):
|
| 71 |
-
pairs = [(query,truncate(t, 2048)) for t in texts]
|
| 72 |
token_count = 0
|
| 73 |
for _, t in pairs:
|
| 74 |
token_count += num_tokens_from_string(t)
|
|
@@ -77,8 +82,10 @@ class DefaultRerank(Base):
|
|
| 77 |
for i in range(0, len(pairs), batch_size):
|
| 78 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
| 79 |
scores = sigmoid(np.array(scores)).tolist()
|
| 80 |
-
if isinstance(scores, float):
|
| 81 |
-
|
|
|
|
|
|
|
| 82 |
return np.array(res), token_count
|
| 83 |
|
| 84 |
|
|
@@ -101,7 +108,10 @@ class JinaRerank(Base):
|
|
| 101 |
"top_n": len(texts)
|
| 102 |
}
|
| 103 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
|
| 107 |
class YoudaoRerank(DefaultRerank):
|
|
@@ -124,7 +134,7 @@ class YoudaoRerank(DefaultRerank):
|
|
| 124 |
"maidalun1020", "InfiniFlow"))
|
| 125 |
|
| 126 |
self._model = YoudaoRerank._model
|
| 127 |
-
|
| 128 |
def similarity(self, query: str, texts: list):
|
| 129 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
| 130 |
token_count = 0
|
|
@@ -135,8 +145,10 @@ class YoudaoRerank(DefaultRerank):
|
|
| 135 |
for i in range(0, len(pairs), batch_size):
|
| 136 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
|
| 137 |
scores = sigmoid(np.array(scores)).tolist()
|
| 138 |
-
if isinstance(scores, float):
|
| 139 |
-
|
|
|
|
|
|
|
| 140 |
return np.array(res), token_count
|
| 141 |
|
| 142 |
|
|
@@ -162,7 +174,10 @@ class XInferenceRerank(Base):
|
|
| 162 |
"documents": texts
|
| 163 |
}
|
| 164 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
|
| 168 |
class LocalAIRerank(Base):
|
|
@@ -175,7 +190,7 @@ class LocalAIRerank(Base):
|
|
| 175 |
|
| 176 |
class NvidiaRerank(Base):
|
| 177 |
def __init__(
|
| 178 |
-
|
| 179 |
):
|
| 180 |
if not base_url:
|
| 181 |
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
|
@@ -208,9 +223,10 @@ class NvidiaRerank(Base):
|
|
| 208 |
"top_n": len(texts),
|
| 209 |
}
|
| 210 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 211 |
-
rank = np.
|
| 212 |
-
|
| 213 |
-
|
|
|
|
| 214 |
|
| 215 |
|
| 216 |
class LmStudioRerank(Base):
|
|
@@ -247,9 +263,10 @@ class CoHereRerank(Base):
|
|
| 247 |
top_n=len(texts),
|
| 248 |
return_documents=False,
|
| 249 |
)
|
| 250 |
-
rank = np.
|
| 251 |
-
|
| 252 |
-
|
|
|
|
| 253 |
|
| 254 |
|
| 255 |
class TogetherAIRerank(Base):
|
|
@@ -262,7 +279,7 @@ class TogetherAIRerank(Base):
|
|
| 262 |
|
| 263 |
class SILICONFLOWRerank(Base):
|
| 264 |
def __init__(
|
| 265 |
-
|
| 266 |
):
|
| 267 |
if not base_url:
|
| 268 |
base_url = "https://api.siliconflow.cn/v1/rerank"
|
|
@@ -287,10 +304,11 @@ class SILICONFLOWRerank(Base):
|
|
| 287 |
response = requests.post(
|
| 288 |
self.base_url, json=payload, headers=self.headers
|
| 289 |
).json()
|
| 290 |
-
rank = np.
|
| 291 |
-
|
|
|
|
| 292 |
return (
|
| 293 |
-
rank
|
| 294 |
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
|
| 295 |
)
|
| 296 |
|
|
@@ -312,9 +330,10 @@ class BaiduYiyanRerank(Base):
|
|
| 312 |
documents=texts,
|
| 313 |
top_n=len(texts),
|
| 314 |
).body
|
| 315 |
-
rank = np.
|
| 316 |
-
|
| 317 |
-
|
|
|
|
| 318 |
|
| 319 |
|
| 320 |
class VoyageRerank(Base):
|
|
@@ -328,6 +347,7 @@ class VoyageRerank(Base):
|
|
| 328 |
res = self.client.rerank(
|
| 329 |
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
| 330 |
)
|
| 331 |
-
rank = np.
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
| 26 |
from rag.utils import num_tokens_from_string, truncate
|
| 27 |
import json
|
| 28 |
|
| 29 |
+
|
| 30 |
def sigmoid(x):
|
| 31 |
return 1 / (1 + np.exp(-x))
|
| 32 |
|
| 33 |
+
|
| 34 |
class Base(ABC):
|
| 35 |
def __init__(self, key, model_name):
|
| 36 |
pass
|
|
|
|
| 61 |
with DefaultRerank._model_lock:
|
| 62 |
if not DefaultRerank._model:
|
| 63 |
try:
|
| 64 |
+
DefaultRerank._model = FlagReranker(
|
| 65 |
+
os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
| 66 |
+
use_fp16=torch.cuda.is_available())
|
| 67 |
except Exception as e:
|
| 68 |
+
model_dir = snapshot_download(repo_id=model_name,
|
| 69 |
+
local_dir=os.path.join(get_home_cache_dir(),
|
| 70 |
+
re.sub(r"^[a-zA-Z]+/", "", model_name)),
|
| 71 |
local_dir_use_symlinks=False)
|
| 72 |
DefaultRerank._model = FlagReranker(model_dir, use_fp16=torch.cuda.is_available())
|
| 73 |
self._model = DefaultRerank._model
|
| 74 |
|
| 75 |
def similarity(self, query: str, texts: list):
|
| 76 |
+
pairs = [(query, truncate(t, 2048)) for t in texts]
|
| 77 |
token_count = 0
|
| 78 |
for _, t in pairs:
|
| 79 |
token_count += num_tokens_from_string(t)
|
|
|
|
| 82 |
for i in range(0, len(pairs), batch_size):
|
| 83 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=2048)
|
| 84 |
scores = sigmoid(np.array(scores)).tolist()
|
| 85 |
+
if isinstance(scores, float):
|
| 86 |
+
res.append(scores)
|
| 87 |
+
else:
|
| 88 |
+
res.extend(scores)
|
| 89 |
return np.array(res), token_count
|
| 90 |
|
| 91 |
|
|
|
|
| 108 |
"top_n": len(texts)
|
| 109 |
}
|
| 110 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 111 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 112 |
+
for d in res["results"]:
|
| 113 |
+
rank[d["index"]] = d["relevance_score"]
|
| 114 |
+
return rank, res["usage"]["total_tokens"]
|
| 115 |
|
| 116 |
|
| 117 |
class YoudaoRerank(DefaultRerank):
|
|
|
|
| 134 |
"maidalun1020", "InfiniFlow"))
|
| 135 |
|
| 136 |
self._model = YoudaoRerank._model
|
| 137 |
+
|
| 138 |
def similarity(self, query: str, texts: list):
|
| 139 |
pairs = [(query, truncate(t, self._model.max_length)) for t in texts]
|
| 140 |
token_count = 0
|
|
|
|
| 145 |
for i in range(0, len(pairs), batch_size):
|
| 146 |
scores = self._model.compute_score(pairs[i:i + batch_size], max_length=self._model.max_length)
|
| 147 |
scores = sigmoid(np.array(scores)).tolist()
|
| 148 |
+
if isinstance(scores, float):
|
| 149 |
+
res.append(scores)
|
| 150 |
+
else:
|
| 151 |
+
res.extend(scores)
|
| 152 |
return np.array(res), token_count
|
| 153 |
|
| 154 |
|
|
|
|
| 174 |
"documents": texts
|
| 175 |
}
|
| 176 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 177 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 178 |
+
for d in res["results"]:
|
| 179 |
+
rank[d["index"]] = d["relevance_score"]
|
| 180 |
+
return rank, res["meta"]["tokens"]["input_tokens"] + res["meta"]["tokens"]["output_tokens"]
|
| 181 |
|
| 182 |
|
| 183 |
class LocalAIRerank(Base):
|
|
|
|
| 190 |
|
| 191 |
class NvidiaRerank(Base):
|
| 192 |
def __init__(
|
| 193 |
+
self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
| 194 |
):
|
| 195 |
if not base_url:
|
| 196 |
base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
|
|
|
|
| 223 |
"top_n": len(texts),
|
| 224 |
}
|
| 225 |
res = requests.post(self.base_url, headers=self.headers, json=data).json()
|
| 226 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 227 |
+
for d in res["rankings"]:
|
| 228 |
+
rank[d["index"]] = d["logit"]
|
| 229 |
+
return rank, token_count
|
| 230 |
|
| 231 |
|
| 232 |
class LmStudioRerank(Base):
|
|
|
|
| 263 |
top_n=len(texts),
|
| 264 |
return_documents=False,
|
| 265 |
)
|
| 266 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 267 |
+
for d in res.results:
|
| 268 |
+
rank[d.index] = d.relevance_score
|
| 269 |
+
return rank, token_count
|
| 270 |
|
| 271 |
|
| 272 |
class TogetherAIRerank(Base):
|
|
|
|
| 279 |
|
| 280 |
class SILICONFLOWRerank(Base):
|
| 281 |
def __init__(
|
| 282 |
+
self, key, model_name, base_url="https://api.siliconflow.cn/v1/rerank"
|
| 283 |
):
|
| 284 |
if not base_url:
|
| 285 |
base_url = "https://api.siliconflow.cn/v1/rerank"
|
|
|
|
| 304 |
response = requests.post(
|
| 305 |
self.base_url, json=payload, headers=self.headers
|
| 306 |
).json()
|
| 307 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 308 |
+
for d in response["results"]:
|
| 309 |
+
rank[d["index"]] = d["relevance_score"]
|
| 310 |
return (
|
| 311 |
+
rank,
|
| 312 |
response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
|
| 313 |
)
|
| 314 |
|
|
|
|
| 330 |
documents=texts,
|
| 331 |
top_n=len(texts),
|
| 332 |
).body
|
| 333 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 334 |
+
for d in res["results"]:
|
| 335 |
+
rank[d["index"]] = d["relevance_score"]
|
| 336 |
+
return rank, res["usage"]["total_tokens"]
|
| 337 |
|
| 338 |
|
| 339 |
class VoyageRerank(Base):
|
|
|
|
| 347 |
res = self.client.rerank(
|
| 348 |
query=query, documents=texts, model=self.model_name, top_k=len(texts)
|
| 349 |
)
|
| 350 |
+
rank = np.zeros(len(texts), dtype=float)
|
| 351 |
+
for r in res.results:
|
| 352 |
+
rank[r.index] = r.relevance_score
|
| 353 |
+
return rank, res.total_tokens
|
sdk/python/ragflow/modules/assistant.py
CHANGED
|
@@ -76,7 +76,7 @@ class Assistant(Base):
|
|
| 76 |
raise Exception(res["retmsg"])
|
| 77 |
|
| 78 |
def get_session(self, id) -> Session:
|
| 79 |
-
res = self.get("/session/get", {"id": id})
|
| 80 |
res = res.json()
|
| 81 |
if res.get("retmsg") == "success":
|
| 82 |
return Session(self.rag, res["data"])
|
|
|
|
| 76 |
raise Exception(res["retmsg"])
|
| 77 |
|
| 78 |
def get_session(self, id) -> Session:
|
| 79 |
+
res = self.get("/session/get", {"id": id,"assistant_id":self.id})
|
| 80 |
res = res.json()
|
| 81 |
if res.get("retmsg") == "success":
|
| 82 |
return Session(self.rag, res["data"])
|
sdk/python/ragflow/modules/session.py
CHANGED
|
@@ -16,9 +16,12 @@ class Session(Base):
|
|
| 16 |
if "reference" in message:
|
| 17 |
message.pop("reference")
|
| 18 |
res = self.post("/session/completion",
|
| 19 |
-
{"
|
| 20 |
for line in res.iter_lines():
|
| 21 |
line = line.decode("utf-8")
|
|
|
|
|
|
|
|
|
|
| 22 |
if line.startswith("data:"):
|
| 23 |
json_data = json.loads(line[5:])
|
| 24 |
if json_data["data"] != True:
|
|
@@ -69,6 +72,7 @@ class Message(Base):
|
|
| 69 |
self.reference = None
|
| 70 |
self.role = "assistant"
|
| 71 |
self.prompt = None
|
|
|
|
| 72 |
super().__init__(rag, res_dict)
|
| 73 |
|
| 74 |
|
|
@@ -76,10 +80,10 @@ class Chunk(Base):
|
|
| 76 |
def __init__(self, rag, res_dict):
|
| 77 |
self.id = None
|
| 78 |
self.content = None
|
| 79 |
-
self.document_id =
|
| 80 |
-
self.document_name =
|
| 81 |
-
self.knowledgebase_id =
|
| 82 |
-
self.image_id =
|
| 83 |
self.similarity = None
|
| 84 |
self.vector_similarity = None
|
| 85 |
self.term_similarity = None
|
|
|
|
| 16 |
if "reference" in message:
|
| 17 |
message.pop("reference")
|
| 18 |
res = self.post("/session/completion",
|
| 19 |
+
{"session_id": self.id, "question": question, "stream": True}, stream=stream)
|
| 20 |
for line in res.iter_lines():
|
| 21 |
line = line.decode("utf-8")
|
| 22 |
+
if line.startswith("{"):
|
| 23 |
+
json_data = json.loads(line)
|
| 24 |
+
raise Exception(json_data["retmsg"])
|
| 25 |
if line.startswith("data:"):
|
| 26 |
json_data = json.loads(line[5:])
|
| 27 |
if json_data["data"] != True:
|
|
|
|
| 72 |
self.reference = None
|
| 73 |
self.role = "assistant"
|
| 74 |
self.prompt = None
|
| 75 |
+
self.id = None
|
| 76 |
super().__init__(rag, res_dict)
|
| 77 |
|
| 78 |
|
|
|
|
| 80 |
def __init__(self, rag, res_dict):
|
| 81 |
self.id = None
|
| 82 |
self.content = None
|
| 83 |
+
self.document_id = ""
|
| 84 |
+
self.document_name = ""
|
| 85 |
+
self.knowledgebase_id = ""
|
| 86 |
+
self.image_id = ""
|
| 87 |
self.similarity = None
|
| 88 |
self.vector_similarity = None
|
| 89 |
self.term_similarity = None
|
sdk/python/test/t_session.py
CHANGED
|
@@ -19,7 +19,7 @@ class TestSession:
|
|
| 19 |
question = "What is AI"
|
| 20 |
for ans in session.chat(question, stream=True):
|
| 21 |
pass
|
| 22 |
-
assert ans.content
|
| 23 |
|
| 24 |
def test_delete_session_with_success(self):
|
| 25 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
|
|
|
| 19 |
question = "What is AI"
|
| 20 |
for ans in session.chat(question, stream=True):
|
| 21 |
pass
|
| 22 |
+
assert not ans.content.startswith("**ERROR**"), "Please check this error."
|
| 23 |
|
| 24 |
def test_delete_session_with_success(self):
|
| 25 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|