:gem: [Feature] Support new model: openchat-3.5-0106
Browse files- apis/chat_api.py +10 -3
- messagers/message_composer.py +30 -22
- networks/message_streamer.py +4 -2
- requirements.txt +1 -0
apis/chat_api.py
CHANGED
|
@@ -40,6 +40,13 @@ class ChatAPIApp:
|
|
| 40 |
"created": 1700000000,
|
| 41 |
"owned_by": "mistralai",
|
| 42 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
{
|
| 44 |
"id": "mistral-7b",
|
| 45 |
"description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
|
|
@@ -48,11 +55,11 @@ class ChatAPIApp:
|
|
| 48 |
"owned_by": "mistralai",
|
| 49 |
},
|
| 50 |
{
|
| 51 |
-
"id": "
|
| 52 |
-
"description": "[
|
| 53 |
"object": "model",
|
| 54 |
"created": 1700000000,
|
| 55 |
-
"owned_by": "
|
| 56 |
},
|
| 57 |
{
|
| 58 |
"id": "gemma-7b",
|
|
|
|
| 40 |
"created": 1700000000,
|
| 41 |
"owned_by": "mistralai",
|
| 42 |
},
|
| 43 |
+
{
|
| 44 |
+
"id": "nous-mixtral-8x7b",
|
| 45 |
+
"description": "[NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO]: https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
|
| 46 |
+
"object": "model",
|
| 47 |
+
"created": 1700000000,
|
| 48 |
+
"owned_by": "NousResearch",
|
| 49 |
+
},
|
| 50 |
{
|
| 51 |
"id": "mistral-7b",
|
| 52 |
"description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
|
|
|
|
| 55 |
"owned_by": "mistralai",
|
| 56 |
},
|
| 57 |
{
|
| 58 |
+
"id": "openchat-3.5",
|
| 59 |
+
"description": "[openchat/openchat-3.5-0106]: https://huggingface.co/openchat/openchat-3.5-0106",
|
| 60 |
"object": "model",
|
| 61 |
"created": 1700000000,
|
| 62 |
+
"owned_by": "openchat",
|
| 63 |
},
|
| 64 |
{
|
| 65 |
"id": "gemma-7b",
|
messagers/message_composer.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import re
|
| 2 |
from pprint import pprint
|
| 3 |
from utils.logger import logger
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
class MessageComposer:
|
|
@@ -8,8 +9,8 @@ class MessageComposer:
|
|
| 8 |
AVALAIBLE_MODELS = [
|
| 9 |
"mixtral-8x7b",
|
| 10 |
"mistral-7b",
|
| 11 |
-
"openchat-3.5",
|
| 12 |
"nous-mixtral-8x7b",
|
|
|
|
| 13 |
"gemma-7b",
|
| 14 |
]
|
| 15 |
|
|
@@ -102,26 +103,30 @@ class MessageComposer:
|
|
| 102 |
self.merged_str = "\n".join(self.merged_str_list)
|
| 103 |
# https://huggingface.co/openchat/openchat-3.5-0106
|
| 104 |
elif self.model in ["openchat-3.5"]:
|
| 105 |
-
|
| 106 |
-
self.
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# https://huggingface.co/google/gemma-7b-it#chat-template
|
| 126 |
elif self.model in ["gemma-7b"]:
|
| 127 |
self.messages = self.concat_messages_by_role(messages)
|
|
@@ -265,7 +270,8 @@ class MessageComposer:
|
|
| 265 |
if __name__ == "__main__":
|
| 266 |
# model = "mixtral-8x7b"
|
| 267 |
# model = "nous-mixtral-8x7b"
|
| 268 |
-
model = "gemma-7b"
|
|
|
|
| 269 |
composer = MessageComposer(model)
|
| 270 |
messages = [
|
| 271 |
{
|
|
@@ -291,3 +297,5 @@ if __name__ == "__main__":
|
|
| 291 |
pprint(composer.split(merged_str))
|
| 292 |
# logger.note("merged merged_str:")
|
| 293 |
# logger.mesg(composer.merge(composer.split(merged_str)))
|
|
|
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
from pprint import pprint
|
| 3 |
from utils.logger import logger
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
|
| 6 |
|
| 7 |
class MessageComposer:
|
|
|
|
| 9 |
AVALAIBLE_MODELS = [
|
| 10 |
"mixtral-8x7b",
|
| 11 |
"mistral-7b",
|
|
|
|
| 12 |
"nous-mixtral-8x7b",
|
| 13 |
+
"openchat-3.5",
|
| 14 |
"gemma-7b",
|
| 15 |
]
|
| 16 |
|
|
|
|
| 103 |
self.merged_str = "\n".join(self.merged_str_list)
|
| 104 |
# https://huggingface.co/openchat/openchat-3.5-0106
|
| 105 |
elif self.model in ["openchat-3.5"]:
|
| 106 |
+
tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
|
| 107 |
+
self.merged_str = tokenizer.apply_chat_template(
|
| 108 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 109 |
+
)
|
| 110 |
+
# self.messages = self.concat_messages_by_role(messages)
|
| 111 |
+
# self.merged_str_list = []
|
| 112 |
+
# self.end_of_turn = "<|end_of_turn|>"
|
| 113 |
+
# for message in self.messages:
|
| 114 |
+
# role = message["role"]
|
| 115 |
+
# content = message["content"]
|
| 116 |
+
# if role in self.inst_roles:
|
| 117 |
+
# self.merged_str_list.append(
|
| 118 |
+
# f"GPT4 Correct User:\n{content}{self.end_of_turn}"
|
| 119 |
+
# )
|
| 120 |
+
# elif role in self.answer_roles:
|
| 121 |
+
# self.merged_str_list.append(
|
| 122 |
+
# f"GPT4 Correct Assistant:\n{content}{self.end_of_turn}"
|
| 123 |
+
# )
|
| 124 |
+
# else:
|
| 125 |
+
# self.merged_str_list.append(
|
| 126 |
+
# f"GPT4 Correct User: {content}{self.end_of_turn}"
|
| 127 |
+
# )
|
| 128 |
+
# self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
|
| 129 |
+
# self.merged_str = "\n".join(self.merged_str_list)
|
| 130 |
# https://huggingface.co/google/gemma-7b-it#chat-template
|
| 131 |
elif self.model in ["gemma-7b"]:
|
| 132 |
self.messages = self.concat_messages_by_role(messages)
|
|
|
|
| 270 |
if __name__ == "__main__":
|
| 271 |
# model = "mixtral-8x7b"
|
| 272 |
# model = "nous-mixtral-8x7b"
|
| 273 |
+
# model = "gemma-7b"
|
| 274 |
+
model = "openchat-3.5"
|
| 275 |
composer = MessageComposer(model)
|
| 276 |
messages = [
|
| 277 |
{
|
|
|
|
| 297 |
pprint(composer.split(merged_str))
|
| 298 |
# logger.note("merged merged_str:")
|
| 299 |
# logger.mesg(composer.merge(composer.split(merged_str)))
|
| 300 |
+
|
| 301 |
+
# python -m messagers.message_composer
|
networks/message_streamer.py
CHANGED
|
@@ -5,6 +5,7 @@ from tiktoken import get_encoding as tiktoken_get_encoding
|
|
| 5 |
from messagers.message_outputer import OpenaiStreamOutputer
|
| 6 |
from utils.logger import logger
|
| 7 |
from utils.enver import enver
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class MessageStreamer:
|
|
@@ -12,8 +13,8 @@ class MessageStreamer:
|
|
| 12 |
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", # 72.62, fast [Recommended]
|
| 13 |
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2", # 65.71, fast
|
| 14 |
"nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
|
|
|
|
| 15 |
"gemma-7b": "google/gemma-7b-it",
|
| 16 |
-
# "openchat-3.5": "openchat/openchat-3.5-1210", # 68.89, fast
|
| 17 |
# "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta", # ❌ Too Slow
|
| 18 |
# "llama-70b": "meta-llama/Llama-2-70b-chat-hf", # ❌ Require Pro User
|
| 19 |
# "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf", # ❌ Low Score
|
|
@@ -43,7 +44,8 @@ class MessageStreamer:
|
|
| 43 |
self.model = "default"
|
| 44 |
self.model_fullname = self.MODEL_MAP[self.model]
|
| 45 |
self.message_outputer = OpenaiStreamOutputer()
|
| 46 |
-
self.tokenizer = tiktoken_get_encoding("cl100k_base")
|
|
|
|
| 47 |
|
| 48 |
def parse_line(self, line):
|
| 49 |
line = line.decode("utf-8")
|
|
|
|
| 5 |
from messagers.message_outputer import OpenaiStreamOutputer
|
| 6 |
from utils.logger import logger
|
| 7 |
from utils.enver import enver
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
|
| 10 |
|
| 11 |
class MessageStreamer:
|
|
|
|
| 13 |
"mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", # 72.62, fast [Recommended]
|
| 14 |
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2", # 65.71, fast
|
| 15 |
"nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
|
| 16 |
+
"openchat-3.5": "openchat/openchat-3.5-0106",
|
| 17 |
"gemma-7b": "google/gemma-7b-it",
|
|
|
|
| 18 |
# "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta", # ❌ Too Slow
|
| 19 |
# "llama-70b": "meta-llama/Llama-2-70b-chat-hf", # ❌ Require Pro User
|
| 20 |
# "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf", # ❌ Low Score
|
|
|
|
| 44 |
self.model = "default"
|
| 45 |
self.model_fullname = self.MODEL_MAP[self.model]
|
| 46 |
self.message_outputer = OpenaiStreamOutputer()
|
| 47 |
+
# self.tokenizer = tiktoken_get_encoding("cl100k_base")
|
| 48 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_fullname)
|
| 49 |
|
| 50 |
def parse_line(self, line):
|
| 51 |
line = line.decode("utf-8")
|
requirements.txt
CHANGED
|
@@ -8,5 +8,6 @@ requests
|
|
| 8 |
sse_starlette
|
| 9 |
termcolor
|
| 10 |
tiktoken
|
|
|
|
| 11 |
uvicorn
|
| 12 |
websockets
|
|
|
|
| 8 |
sse_starlette
|
| 9 |
termcolor
|
| 10 |
tiktoken
|
| 11 |
+
transformers
|
| 12 |
uvicorn
|
| 13 |
websockets
|