KevinHuSh committed
Commit · 4825b73
1 Parent(s): a3ebd45

add support for mistral (#1153)
### What problem does this PR solve?
#433
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/db/init_data.py +62 -0
- rag/llm/__init__.py +4 -2
- rag/llm/chat_model.py +54 -0
- rag/llm/embedding_model.py +21 -1
api/db/init_data.py
CHANGED
@@ -157,6 +157,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},{
+    "name": "Mistral",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -584,6 +589,63 @@ def init_llm_factory():
             "max_tokens": 8192,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ Mistral -----------------------
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x22b",
+            "tags": "LLM,CHAT,64k",
+            "max_tokens": 64000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mistral-7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-large-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-small-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-medium-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "codestral-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-embed",
+            "tags": "LLM,CHAT,8k",
+            "max_tokens": 8192,
+            "model_type": LLMType.EMBEDDING
+        },
     ]
     for info in factory_infos:
         try:
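The records added to llm_infos point back to the new factory entry through factory_infos[14]["name"], which only resolves to "Mistral" if the appended factory lands at index 14. A standalone sketch of that linkage (the fourteen placeholder factories below are stand-ins for the existing list, not project code):

```python
# Standalone sketch: the "fid" of each new model record is the factory's name.
factory_infos = [{"name": f"factory-{i}"} for i in range(14)]  # stand-ins for the 14 existing factories
factory_infos.append({
    "name": "Mistral",
    "logo": "",
    "tags": "LLM,TEXT EMBEDDING",
    "status": "1",
})

llm_infos = [{
    "fid": factory_infos[14]["name"],   # resolves to "Mistral"
    "llm_name": "mistral-small-latest",
    "tags": "LLM,CHAT,32k",
    "max_tokens": 32000,
}]

assert llm_infos[0]["fid"] == "Mistral"
```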
rag/llm/__init__.py
CHANGED
@@ -29,7 +29,8 @@ EmbeddingModel = {
     "Youdao": YoudaoEmbed,
     "BaiChuan": BaiChuanEmbed,
     "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding
+    "BAAI": DefaultEmbedding,
+    "Mistral": MistralEmbed
 }

@@ -52,7 +53,8 @@ ChatModel = {
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
     "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat
+    "MiniMax": MiniMaxChat,
+    "Mistral": MistralChat
 }
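With these registrations, callers can resolve the new wrappers by factory name. A minimal sketch, assuming a MISTRAL_API_KEY environment variable and that rag.llm is importable as packaged in this repository; the constructor signatures follow the diffs below:

```python
import os
from rag.llm import ChatModel, EmbeddingModel

# Factory name -> wrapper class, as registered in rag/llm/__init__.py above.
chat_cls = ChatModel["Mistral"]        # MistralChat
embed_cls = EmbeddingModel["Mistral"]  # MistralEmbed

# Both wrappers take (key, model_name, base_url=None).
chat_mdl = chat_cls(os.environ["MISTRAL_API_KEY"], "mistral-small-latest")
embd_mdl = embed_cls(os.environ["MISTRAL_API_KEY"], "mistral-embed")
```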
rag/llm/chat_model.py
CHANGED
@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
         if not base_url:
             base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
         super().__init__(key, model_name, base_url)
+
+
+class MistralChat(Base):
+
+    def __init__(self, key, model_name, base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
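A minimal usage sketch for the new class (the key and prompts are placeholders; `pip install mistralai` is assumed). chat() returns the answer plus total tokens, while chat_streamly() yields the growing answer and finally the token count; both keep only temperature, top_p, and max_tokens from gen_conf:

```python
from rag.llm.chat_model import MistralChat  # assumed import path within this repo

mdl = MistralChat(key="<MISTRAL_API_KEY>", model_name="mistral-small-latest")
gen_conf = {"temperature": 0.3, "max_tokens": 256}

# Blocking call: returns (answer, total tokens used).
answer, used = mdl.chat("You are a concise assistant.",
                        [{"role": "user", "content": "What is RAG?"}],
                        dict(gen_conf))
print(answer, used)

# Streaming call: yields partial answers, then the token count as the last item.
for chunk in mdl.chat_streamly("You are a concise assistant.",
                               [{"role": "user", "content": "What is RAG?"}],
                               dict(gen_conf)):
    print(chunk)
```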
rag/llm/embedding_model.py
CHANGED
@@ -343,4 +343,24 @@ class InfinityEmbed(Base):
     def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
-        return self.encode([text])
+        return self.encode([text])
+
+
+class MistralEmbed(Base):
+    def __init__(self, key, model_name="mistral-embed",
+                 base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8196) for t in texts]
+        res = self.client.embeddings(input=texts,
+                                     model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(input=[truncate(text, 8196)],
+                                     model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
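A minimal usage sketch for MistralEmbed (the key is a placeholder; `pip install mistralai` is assumed). Inputs are cut down via truncate(t, 8196) before being sent:

```python
from rag.llm.embedding_model import MistralEmbed  # assumed import path within this repo

mdl = MistralEmbed(key="<MISTRAL_API_KEY>")  # model_name defaults to "mistral-embed"

# Batch encoding: returns an (n_texts, dim) array and the total tokens billed.
vecs, tokens = mdl.encode(["what is retrieval-augmented generation?",
                           "mistral embeddings example"])
print(vecs.shape, tokens)

# Single-query encoding: returns a (dim,) vector and its token count.
qvec, qtokens = mdl.encode_queries("what is RAG?")
print(qvec.shape, qtokens)
```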