Commit
·
d919631
1
Parent(s):
4d2f593
Change default error message to English (#3838)
Browse files### What problem does this PR solve?
As title
### Type of change
- [x] Refactoring
---------
Signed-off-by: Jin Hai <[email protected]>
- rag/llm/chat_model.py +55 -27
- rag/nlp/__init__.py +8 -0
rag/llm/chat_model.py
CHANGED
|
@@ -22,7 +22,7 @@ from abc import ABC
|
|
| 22 |
from openai import OpenAI
|
| 23 |
import openai
|
| 24 |
from ollama import Client
|
| 25 |
-
from rag.nlp import
|
| 26 |
from rag.utils import num_tokens_from_string
|
| 27 |
from groq import Groq
|
| 28 |
import os
|
|
@@ -30,6 +30,8 @@ import json
|
|
| 30 |
import requests
|
| 31 |
import asyncio
|
| 32 |
|
|
|
|
|
|
|
| 33 |
|
| 34 |
class Base(ABC):
|
| 35 |
def __init__(self, key, model_name, base_url):
|
|
@@ -47,8 +49,10 @@ class Base(ABC):
|
|
| 47 |
**gen_conf)
|
| 48 |
ans = response.choices[0].message.content.strip()
|
| 49 |
if response.choices[0].finish_reason == "length":
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
| 52 |
return ans, response.usage.total_tokens
|
| 53 |
except openai.APIError as e:
|
| 54 |
return "**ERROR**: " + str(e), 0
|
|
@@ -80,8 +84,10 @@ class Base(ABC):
|
|
| 80 |
else: total_tokens = resp.usage.total_tokens
|
| 81 |
|
| 82 |
if resp.choices[0].finish_reason == "length":
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
| 85 |
yield ans
|
| 86 |
|
| 87 |
except openai.APIError as e:
|
|
@@ -167,8 +173,10 @@ class BaiChuanChat(Base):
|
|
| 167 |
**self._format_params(gen_conf))
|
| 168 |
ans = response.choices[0].message.content.strip()
|
| 169 |
if response.choices[0].finish_reason == "length":
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
return ans, response.usage.total_tokens
|
| 173 |
except openai.APIError as e:
|
| 174 |
return "**ERROR**: " + str(e), 0
|
|
@@ -207,8 +215,10 @@ class BaiChuanChat(Base):
|
|
| 207 |
else resp.usage["total_tokens"]
|
| 208 |
)
|
| 209 |
if resp.choices[0].finish_reason == "length":
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
| 212 |
yield ans
|
| 213 |
|
| 214 |
except Exception as e:
|
|
@@ -242,8 +252,10 @@ class QWenChat(Base):
|
|
| 242 |
ans += response.output.choices[0]['message']['content']
|
| 243 |
tk_count += response.usage.total_tokens
|
| 244 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
| 247 |
return ans, tk_count
|
| 248 |
|
| 249 |
return "**ERROR**: " + response.message, tk_count
|
|
@@ -276,8 +288,10 @@ class QWenChat(Base):
|
|
| 276 |
ans = resp.output.choices[0]['message']['content']
|
| 277 |
tk_count = resp.usage.total_tokens
|
| 278 |
if resp.output.choices[0].get("finish_reason", "") == "length":
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
| 281 |
yield ans
|
| 282 |
else:
|
| 283 |
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
|
|
@@ -308,8 +322,10 @@ class ZhipuChat(Base):
|
|
| 308 |
)
|
| 309 |
ans = response.choices[0].message.content.strip()
|
| 310 |
if response.choices[0].finish_reason == "length":
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
| 313 |
return ans, response.usage.total_tokens
|
| 314 |
except Exception as e:
|
| 315 |
return "**ERROR**: " + str(e), 0
|
|
@@ -333,8 +349,10 @@ class ZhipuChat(Base):
|
|
| 333 |
delta = resp.choices[0].delta.content
|
| 334 |
ans += delta
|
| 335 |
if resp.choices[0].finish_reason == "length":
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
| 338 |
tk_count = resp.usage.total_tokens
|
| 339 |
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
| 340 |
yield ans
|
|
@@ -525,8 +543,10 @@ class MiniMaxChat(Base):
|
|
| 525 |
response = response.json()
|
| 526 |
ans = response["choices"][0]["message"]["content"].strip()
|
| 527 |
if response["choices"][0]["finish_reason"] == "length":
|
| 528 |
-
|
| 529 |
-
|
|
|
|
|
|
|
| 530 |
return ans, response["usage"]["total_tokens"]
|
| 531 |
except Exception as e:
|
| 532 |
return "**ERROR**: " + str(e), 0
|
|
@@ -594,8 +614,10 @@ class MistralChat(Base):
|
|
| 594 |
**gen_conf)
|
| 595 |
ans = response.choices[0].message.content
|
| 596 |
if response.choices[0].finish_reason == "length":
|
| 597 |
-
|
| 598 |
-
|
|
|
|
|
|
|
| 599 |
return ans, response.usage.total_tokens
|
| 600 |
except openai.APIError as e:
|
| 601 |
return "**ERROR**: " + str(e), 0
|
|
@@ -618,8 +640,10 @@ class MistralChat(Base):
|
|
| 618 |
ans += resp.choices[0].delta.content
|
| 619 |
total_tokens += 1
|
| 620 |
if resp.choices[0].finish_reason == "length":
|
| 621 |
-
|
| 622 |
-
|
|
|
|
|
|
|
| 623 |
yield ans
|
| 624 |
|
| 625 |
except openai.APIError as e:
|
|
@@ -811,8 +835,10 @@ class GroqChat:
|
|
| 811 |
)
|
| 812 |
ans = response.choices[0].message.content
|
| 813 |
if response.choices[0].finish_reason == "length":
|
| 814 |
-
|
| 815 |
-
|
|
|
|
|
|
|
| 816 |
return ans, response.usage.total_tokens
|
| 817 |
except Exception as e:
|
| 818 |
return ans + "\n**ERROR**: " + str(e), 0
|
|
@@ -838,8 +864,10 @@ class GroqChat:
|
|
| 838 |
ans += resp.choices[0].delta.content
|
| 839 |
total_tokens += 1
|
| 840 |
if resp.choices[0].finish_reason == "length":
|
| 841 |
-
|
| 842 |
-
|
|
|
|
|
|
|
| 843 |
yield ans
|
| 844 |
|
| 845 |
except Exception as e:
|
|
|
|
| 22 |
from openai import OpenAI
|
| 23 |
import openai
|
| 24 |
from ollama import Client
|
| 25 |
+
from rag.nlp import is_chinese
|
| 26 |
from rag.utils import num_tokens_from_string
|
| 27 |
from groq import Groq
|
| 28 |
import os
|
|
|
|
| 30 |
import requests
|
| 31 |
import asyncio
|
| 32 |
|
| 33 |
+
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,请继续吗?"
|
| 34 |
+
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
|
| 35 |
|
| 36 |
class Base(ABC):
|
| 37 |
def __init__(self, key, model_name, base_url):
|
|
|
|
| 49 |
**gen_conf)
|
| 50 |
ans = response.choices[0].message.content.strip()
|
| 51 |
if response.choices[0].finish_reason == "length":
|
| 52 |
+
if is_chinese(ans):
|
| 53 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 54 |
+
else:
|
| 55 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 56 |
return ans, response.usage.total_tokens
|
| 57 |
except openai.APIError as e:
|
| 58 |
return "**ERROR**: " + str(e), 0
|
|
|
|
| 84 |
else: total_tokens = resp.usage.total_tokens
|
| 85 |
|
| 86 |
if resp.choices[0].finish_reason == "length":
|
| 87 |
+
if is_chinese(ans):
|
| 88 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 89 |
+
else:
|
| 90 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 91 |
yield ans
|
| 92 |
|
| 93 |
except openai.APIError as e:
|
|
|
|
| 173 |
**self._format_params(gen_conf))
|
| 174 |
ans = response.choices[0].message.content.strip()
|
| 175 |
if response.choices[0].finish_reason == "length":
|
| 176 |
+
if is_chinese([ans]):
|
| 177 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 178 |
+
else:
|
| 179 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 180 |
return ans, response.usage.total_tokens
|
| 181 |
except openai.APIError as e:
|
| 182 |
return "**ERROR**: " + str(e), 0
|
|
|
|
| 215 |
else resp.usage["total_tokens"]
|
| 216 |
)
|
| 217 |
if resp.choices[0].finish_reason == "length":
|
| 218 |
+
if is_chinese([ans]):
|
| 219 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 220 |
+
else:
|
| 221 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 222 |
yield ans
|
| 223 |
|
| 224 |
except Exception as e:
|
|
|
|
| 252 |
ans += response.output.choices[0]['message']['content']
|
| 253 |
tk_count += response.usage.total_tokens
|
| 254 |
if response.output.choices[0].get("finish_reason", "") == "length":
|
| 255 |
+
if is_chinese([ans]):
|
| 256 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 257 |
+
else:
|
| 258 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 259 |
return ans, tk_count
|
| 260 |
|
| 261 |
return "**ERROR**: " + response.message, tk_count
|
|
|
|
| 288 |
ans = resp.output.choices[0]['message']['content']
|
| 289 |
tk_count = resp.usage.total_tokens
|
| 290 |
if resp.output.choices[0].get("finish_reason", "") == "length":
|
| 291 |
+
if is_chinese(ans):
|
| 292 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 293 |
+
else:
|
| 294 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 295 |
yield ans
|
| 296 |
else:
|
| 297 |
yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)", str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
|
|
|
|
| 322 |
)
|
| 323 |
ans = response.choices[0].message.content.strip()
|
| 324 |
if response.choices[0].finish_reason == "length":
|
| 325 |
+
if is_chinese(ans):
|
| 326 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 327 |
+
else:
|
| 328 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 329 |
return ans, response.usage.total_tokens
|
| 330 |
except Exception as e:
|
| 331 |
return "**ERROR**: " + str(e), 0
|
|
|
|
| 349 |
delta = resp.choices[0].delta.content
|
| 350 |
ans += delta
|
| 351 |
if resp.choices[0].finish_reason == "length":
|
| 352 |
+
if is_chinese(ans):
|
| 353 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 354 |
+
else:
|
| 355 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 356 |
tk_count = resp.usage.total_tokens
|
| 357 |
if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
|
| 358 |
yield ans
|
|
|
|
| 543 |
response = response.json()
|
| 544 |
ans = response["choices"][0]["message"]["content"].strip()
|
| 545 |
if response["choices"][0]["finish_reason"] == "length":
|
| 546 |
+
if is_chinese(ans):
|
| 547 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 548 |
+
else:
|
| 549 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 550 |
return ans, response["usage"]["total_tokens"]
|
| 551 |
except Exception as e:
|
| 552 |
return "**ERROR**: " + str(e), 0
|
|
|
|
| 614 |
**gen_conf)
|
| 615 |
ans = response.choices[0].message.content
|
| 616 |
if response.choices[0].finish_reason == "length":
|
| 617 |
+
if is_chinese(ans):
|
| 618 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 619 |
+
else:
|
| 620 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 621 |
return ans, response.usage.total_tokens
|
| 622 |
except openai.APIError as e:
|
| 623 |
return "**ERROR**: " + str(e), 0
|
|
|
|
| 640 |
ans += resp.choices[0].delta.content
|
| 641 |
total_tokens += 1
|
| 642 |
if resp.choices[0].finish_reason == "length":
|
| 643 |
+
if is_chinese(ans):
|
| 644 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 645 |
+
else:
|
| 646 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 647 |
yield ans
|
| 648 |
|
| 649 |
except openai.APIError as e:
|
|
|
|
| 835 |
)
|
| 836 |
ans = response.choices[0].message.content
|
| 837 |
if response.choices[0].finish_reason == "length":
|
| 838 |
+
if is_chinese(ans):
|
| 839 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 840 |
+
else:
|
| 841 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 842 |
return ans, response.usage.total_tokens
|
| 843 |
except Exception as e:
|
| 844 |
return ans + "\n**ERROR**: " + str(e), 0
|
|
|
|
| 864 |
ans += resp.choices[0].delta.content
|
| 865 |
total_tokens += 1
|
| 866 |
if resp.choices[0].finish_reason == "length":
|
| 867 |
+
if is_chinese(ans):
|
| 868 |
+
ans += LENGTH_NOTIFICATION_CN
|
| 869 |
+
else:
|
| 870 |
+
ans += LENGTH_NOTIFICATION_EN
|
| 871 |
yield ans
|
| 872 |
|
| 873 |
except Exception as e:
|
rag/nlp/__init__.py
CHANGED
|
@@ -230,6 +230,14 @@ def is_english(texts):
|
|
| 230 |
return True
|
| 231 |
return False
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
def tokenize(d, t, eng):
|
| 235 |
d["content_with_weight"] = t
|
|
|
|
| 230 |
return True
|
| 231 |
return False
|
| 232 |
|
| 233 |
+
def is_chinese(text):
    """Return True if *text* is predominantly Chinese.

    A character counts as Chinese when it falls in the CJK Unified
    Ideographs block (U+4E00..U+9FFF).

    Args:
        text: string (or iterable of single characters) to classify.

    Returns:
        bool: True when the ratio of CJK characters to total length
        strictly exceeds 0.2; False otherwise. An empty input returns
        False instead of raising ZeroDivisionError (bug fix — the
        original divided by ``len(text)`` unconditionally).
    """
    total = len(text)
    if total == 0:
        # Guard: empty answer strings reach this helper (e.g. a truncated
        # LLM reply); treat them as non-Chinese rather than crashing.
        return False
    chinese = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            chinese += 1
            # Ratio only changes when the counter does, so the early-exit
            # check is needed here only — same threshold as the original.
            if chinese / total > 0.2:
                return True
    return False
|
| 241 |
|
| 242 |
def tokenize(d, t, eng):
|
| 243 |
d["content_with_weight"] = t
|