KevinHuSh
committed on
Commit
·
4d5f1c9
1
Parent(s):
d11e82f
add support for deepseek (#668)
Browse files
### What problem does this PR solve?
#666
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- api/db/init_data.py +21 -1
- rag/llm/__init__.py +2 -1
- rag/llm/chat_model.py +20 -51
api/db/init_data.py
CHANGED
|
@@ -123,7 +123,12 @@ factory_infos = [{
|
|
| 123 |
"name": "Youdao",
|
| 124 |
"logo": "",
|
| 125 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
},
|
| 128 |
# {
|
| 129 |
# "name": "文心一言",
|
|
@@ -331,6 +336,21 @@ def init_llm_factory():
|
|
| 331 |
"max_tokens": 512,
|
| 332 |
"model_type": LLMType.EMBEDDING.value
|
| 333 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
]
|
| 335 |
for info in factory_infos:
|
| 336 |
try:
|
|
|
|
| 123 |
"name": "Youdao",
|
| 124 |
"logo": "",
|
| 125 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 126 |
+
"status": "1",
|
| 127 |
+
},{
|
| 128 |
+
"name": "DeepSeek",
|
| 129 |
+
"logo": "",
|
| 130 |
+
"tags": "LLM",
|
| 131 |
+
"status": "1",
|
| 132 |
},
|
| 133 |
# {
|
| 134 |
# "name": "文心一言",
|
|
|
|
| 336 |
"max_tokens": 512,
|
| 337 |
"model_type": LLMType.EMBEDDING.value
|
| 338 |
},
|
| 339 |
+
# ------------------------ DeepSeek -----------------------
|
| 340 |
+
{
|
| 341 |
+
"fid": factory_infos[8]["name"],
|
| 342 |
+
"llm_name": "deepseek-chat",
|
| 343 |
+
"tags": "LLM,CHAT,",
|
| 344 |
+
"max_tokens": 32768,
|
| 345 |
+
"model_type": LLMType.CHAT.value
|
| 346 |
+
},
|
| 347 |
+
{
|
| 348 |
+
"fid": factory_infos[8]["name"],
|
| 349 |
+
"llm_name": "deepseek-coder",
|
| 350 |
+
"tags": "LLM,CHAT,",
|
| 351 |
+
"max_tokens": 16385,
|
| 352 |
+
"model_type": LLMType.CHAT.value
|
| 353 |
+
},
|
| 354 |
]
|
| 355 |
for info in factory_infos:
|
| 356 |
try:
|
rag/llm/__init__.py
CHANGED
|
@@ -45,6 +45,7 @@ ChatModel = {
|
|
| 45 |
"Tongyi-Qianwen": QWenChat,
|
| 46 |
"Ollama": OllamaChat,
|
| 47 |
"Xinference": XinferenceChat,
|
| 48 |
-
"Moonshot": MoonshotChat
|
|
|
|
| 49 |
}
|
| 50 |
|
|
|
|
| 45 |
"Tongyi-Qianwen": QWenChat,
|
| 46 |
"Ollama": OllamaChat,
|
| 47 |
"Xinference": XinferenceChat,
|
| 48 |
+
"Moonshot": MoonshotChat,
|
| 49 |
+
"DeepSeek": DeepSeekChat
|
| 50 |
}
|
| 51 |
|
rag/llm/chat_model.py
CHANGED
|
@@ -24,16 +24,7 @@ from rag.utils import num_tokens_from_string
|
|
| 24 |
|
| 25 |
|
| 26 |
class Base(ABC):
|
| 27 |
-
def __init__(self, key, model_name):
|
| 28 |
-
pass
|
| 29 |
-
|
| 30 |
-
def chat(self, system, history, gen_conf):
|
| 31 |
-
raise NotImplementedError("Please implement encode method!")
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
class GptTurbo(Base):
|
| 35 |
-
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
| 36 |
-
if not base_url: base_url="https://api.openai.com/v1"
|
| 37 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
| 38 |
self.model_name = model_name
|
| 39 |
|
|
@@ -54,28 +45,28 @@ class GptTurbo(Base):
|
|
| 54 |
return "**ERROR**: " + str(e), 0
|
| 55 |
|
| 56 |
|
| 57 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
|
| 59 |
if not base_url: base_url="https://api.moonshot.cn/v1"
|
| 60 |
-
|
| 61 |
-
api_key=key, base_url=base_url)
|
| 62 |
-
self.model_name = model_name
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
| 76 |
-
return ans, response.usage.total_tokens
|
| 77 |
-
except openai.APIError as e:
|
| 78 |
-
return "**ERROR**: " + str(e), 0
|
| 79 |
|
| 80 |
|
| 81 |
class QWenChat(Base):
|
|
@@ -157,25 +148,3 @@ class OllamaChat(Base):
|
|
| 157 |
except Exception as e:
|
| 158 |
return "**ERROR**: " + str(e), 0
|
| 159 |
|
| 160 |
-
|
| 161 |
-
class XinferenceChat(Base):
|
| 162 |
-
def __init__(self, key=None, model_name="", base_url=""):
|
| 163 |
-
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
| 164 |
-
self.model_name = model_name
|
| 165 |
-
|
| 166 |
-
def chat(self, system, history, gen_conf):
|
| 167 |
-
if system:
|
| 168 |
-
history.insert(0, {"role": "system", "content": system})
|
| 169 |
-
try:
|
| 170 |
-
response = self.client.chat.completions.create(
|
| 171 |
-
model=self.model_name,
|
| 172 |
-
messages=history,
|
| 173 |
-
**gen_conf)
|
| 174 |
-
ans = response.choices[0].message.content.strip()
|
| 175 |
-
if response.choices[0].finish_reason == "length":
|
| 176 |
-
ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
|
| 177 |
-
[ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
|
| 178 |
-
return ans, response.usage.total_tokens
|
| 179 |
-
except openai.APIError as e:
|
| 180 |
-
return "**ERROR**: " + str(e), 0
|
| 181 |
-
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class Base(ABC):
|
| 27 |
+
def __init__(self, key, model_name, base_url):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
self.client = OpenAI(api_key=key, base_url=base_url)
|
| 29 |
self.model_name = model_name
|
| 30 |
|
|
|
|
| 45 |
return "**ERROR**: " + str(e), 0
|
| 46 |
|
| 47 |
|
| 48 |
+
class GptTurbo(Base):
|
| 49 |
+
def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
|
| 50 |
+
if not base_url: base_url="https://api.openai.com/v1"
|
| 51 |
+
super().__init__(key, model_name, base_url)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MoonshotChat(Base):
|
| 55 |
def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
|
| 56 |
if not base_url: base_url="https://api.moonshot.cn/v1"
|
| 57 |
+
super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
|
| 60 |
+
class XinferenceChat(Base):
|
| 61 |
+
def __init__(self, key=None, model_name="", base_url=""):
|
| 62 |
+
key = "xxx"
|
| 63 |
+
super().__init__(key, model_name, base_url)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class DeepSeekChat(Base):
|
| 67 |
+
def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
|
| 68 |
+
if not base_url: base_url="https://api.deepseek.com/v1"
|
| 69 |
+
super().__init__(key, model_name, base_url)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
class QWenChat(Base):
|
|
|
|
| 148 |
except Exception as e:
|
| 149 |
return "**ERROR**: " + str(e), 0
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|