H
KevinHuSh
commited on
Commit
·
dffdcde
1
Parent(s):
255441a
Add Support for AWS Bedrock (#1408)
Browse files### What problem does this PR solve?
#308
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: KevinHuSh <[email protected]>
- api/apps/llm_app.py +14 -4
- api/db/init_data.py +169 -1
- rag/llm/__init__.py +4 -2
- rag/llm/chat_model.py +87 -0
- rag/llm/embedding_model.py +45 -0
- requirements.txt +2 -0
- requirements_arm.txt +2 -0
- requirements_dev.txt +2 -0
api/apps/llm_app.py
CHANGED
|
@@ -109,15 +109,23 @@ def set_api_key():
|
|
| 109 |
def add_llm():
|
| 110 |
req = request.json
|
| 111 |
factory = req["llm_factory"]
|
| 112 |
-
|
| 113 |
-
# Assemble volc_ak, volc_sk, endpoint_id into api_key
|
| 114 |
if factory == "VolcEngine":
|
|
|
|
|
|
|
| 115 |
temp = list(eval(req["llm_name"]).items())[0]
|
| 116 |
llm_name = temp[0]
|
| 117 |
endpoint_id = temp[1]
|
| 118 |
api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
|
| 119 |
f'"volc_sk": "{req.get("volc_sk", "")}", ' \
|
| 120 |
f'"ep_id": "{endpoint_id}", ' + '}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
else:
|
| 122 |
llm_name = req["llm_name"]
|
| 123 |
api_key = "xxxxxxxxxxxxxxx"
|
|
@@ -134,7 +142,9 @@ def add_llm():
|
|
| 134 |
msg = ""
|
| 135 |
if llm["model_type"] == LLMType.EMBEDDING.value:
|
| 136 |
mdl = EmbeddingModel[factory](
|
| 137 |
-
key=
|
|
|
|
|
|
|
| 138 |
try:
|
| 139 |
arr, tc = mdl.encode(["Test if the api key is available"])
|
| 140 |
if len(arr[0]) == 0 or tc == 0:
|
|
@@ -143,7 +153,7 @@ def add_llm():
|
|
| 143 |
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
| 144 |
elif llm["model_type"] == LLMType.CHAT.value:
|
| 145 |
mdl = ChatModel[factory](
|
| 146 |
-
key=llm['api_key'] if factory
|
| 147 |
model_name=llm["llm_name"],
|
| 148 |
base_url=llm["api_base"]
|
| 149 |
)
|
|
|
|
| 109 |
def add_llm():
|
| 110 |
req = request.json
|
| 111 |
factory = req["llm_factory"]
|
| 112 |
+
|
|
|
|
| 113 |
if factory == "VolcEngine":
|
| 114 |
+
# For VolcEngine, due to its special authentication method
|
| 115 |
+
# Assemble volc_ak, volc_sk, endpoint_id into api_key
|
| 116 |
temp = list(eval(req["llm_name"]).items())[0]
|
| 117 |
llm_name = temp[0]
|
| 118 |
endpoint_id = temp[1]
|
| 119 |
api_key = '{' + f'"volc_ak": "{req.get("volc_ak", "")}", ' \
|
| 120 |
f'"volc_sk": "{req.get("volc_sk", "")}", ' \
|
| 121 |
f'"ep_id": "{endpoint_id}", ' + '}'
|
| 122 |
+
elif factory == "Bedrock":
|
| 123 |
+
# For Bedrock, due to its special authentication method
|
| 124 |
+
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
| 125 |
+
llm_name = req["llm_name"]
|
| 126 |
+
api_key = '{' + f'"bedrock_ak": "{req.get("bedrock_ak", "")}", ' \
|
| 127 |
+
f'"bedrock_sk": "{req.get("bedrock_sk", "")}", ' \
|
| 128 |
+
f'"bedrock_region": "{req.get("bedrock_region", "")}", ' + '}'
|
| 129 |
else:
|
| 130 |
llm_name = req["llm_name"]
|
| 131 |
api_key = "xxxxxxxxxxxxxxx"
|
|
|
|
| 142 |
msg = ""
|
| 143 |
if llm["model_type"] == LLMType.EMBEDDING.value:
|
| 144 |
mdl = EmbeddingModel[factory](
|
| 145 |
+
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
|
| 146 |
+
model_name=llm["llm_name"],
|
| 147 |
+
base_url=llm["api_base"])
|
| 148 |
try:
|
| 149 |
arr, tc = mdl.encode(["Test if the api key is available"])
|
| 150 |
if len(arr[0]) == 0 or tc == 0:
|
|
|
|
| 153 |
msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
|
| 154 |
elif llm["model_type"] == LLMType.CHAT.value:
|
| 155 |
mdl = ChatModel[factory](
|
| 156 |
+
key=llm['api_key'] if factory in ["VolcEngine", "Bedrock"] else None,
|
| 157 |
model_name=llm["llm_name"],
|
| 158 |
base_url=llm["api_base"]
|
| 159 |
)
|
api/db/init_data.py
CHANGED
|
@@ -170,6 +170,11 @@ factory_infos = [{
|
|
| 170 |
"logo": "",
|
| 171 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 172 |
"status": "1",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
}
|
| 174 |
# {
|
| 175 |
# "name": "文心一言",
|
|
@@ -730,7 +735,170 @@ def init_llm_factory():
|
|
| 730 |
"max_tokens": 765,
|
| 731 |
"model_type": LLMType.IMAGE2TEXT.value
|
| 732 |
},
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
]
|
| 735 |
for info in factory_infos:
|
| 736 |
try:
|
|
|
|
| 170 |
"logo": "",
|
| 171 |
"tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
|
| 172 |
"status": "1",
|
| 173 |
+
},{
|
| 174 |
+
"name": "Bedrock",
|
| 175 |
+
"logo": "",
|
| 176 |
+
"tags": "LLM,TEXT EMBEDDING",
|
| 177 |
+
"status": "1",
|
| 178 |
}
|
| 179 |
# {
|
| 180 |
# "name": "文心一言",
|
|
|
|
| 735 |
"max_tokens": 765,
|
| 736 |
"model_type": LLMType.IMAGE2TEXT.value
|
| 737 |
},
|
| 738 |
+
# ------------------------ Bedrock -----------------------
|
| 739 |
+
{
|
| 740 |
+
"fid": factory_infos[16]["name"],
|
| 741 |
+
"llm_name": "ai21.j2-ultra-v1",
|
| 742 |
+
"tags": "LLM,CHAT,8k",
|
| 743 |
+
"max_tokens": 8191,
|
| 744 |
+
"model_type": LLMType.CHAT.value
|
| 745 |
+
}, {
|
| 746 |
+
"fid": factory_infos[16]["name"],
|
| 747 |
+
"llm_name": "ai21.j2-mid-v1",
|
| 748 |
+
"tags": "LLM,CHAT,8k",
|
| 749 |
+
"max_tokens": 8191,
|
| 750 |
+
"model_type": LLMType.CHAT.value
|
| 751 |
+
}, {
|
| 752 |
+
"fid": factory_infos[16]["name"],
|
| 753 |
+
"llm_name": "cohere.command-text-v14",
|
| 754 |
+
"tags": "LLM,CHAT,4k",
|
| 755 |
+
"max_tokens": 4096,
|
| 756 |
+
"model_type": LLMType.CHAT.value
|
| 757 |
+
}, {
|
| 758 |
+
"fid": factory_infos[16]["name"],
|
| 759 |
+
"llm_name": "cohere.command-light-text-v14",
|
| 760 |
+
"tags": "LLM,CHAT,4k",
|
| 761 |
+
"max_tokens": 4096,
|
| 762 |
+
"model_type": LLMType.CHAT.value
|
| 763 |
+
}, {
|
| 764 |
+
"fid": factory_infos[16]["name"],
|
| 765 |
+
"llm_name": "cohere.command-r-v1:0",
|
| 766 |
+
"tags": "LLM,CHAT,128k",
|
| 767 |
+
"max_tokens": 128 * 1024,
|
| 768 |
+
"model_type": LLMType.CHAT.value
|
| 769 |
+
}, {
|
| 770 |
+
"fid": factory_infos[16]["name"],
|
| 771 |
+
"llm_name": "cohere.command-r-plus-v1:0",
|
| 772 |
+
"tags": "LLM,CHAT,128k",
|
| 773 |
+
"max_tokens": 128000,
|
| 774 |
+
"model_type": LLMType.CHAT.value
|
| 775 |
+
}, {
|
| 776 |
+
"fid": factory_infos[16]["name"],
|
| 777 |
+
"llm_name": "anthropic.claude-v2",
|
| 778 |
+
"tags": "LLM,CHAT,100k",
|
| 779 |
+
"max_tokens": 100 * 1024,
|
| 780 |
+
"model_type": LLMType.CHAT.value
|
| 781 |
+
}, {
|
| 782 |
+
"fid": factory_infos[16]["name"],
|
| 783 |
+
"llm_name": "anthropic.claude-v2:1",
|
| 784 |
+
"tags": "LLM,CHAT,200k",
|
| 785 |
+
"max_tokens": 200 * 1024,
|
| 786 |
+
"model_type": LLMType.CHAT.value
|
| 787 |
+
}, {
|
| 788 |
+
"fid": factory_infos[16]["name"],
|
| 789 |
+
"llm_name": "anthropic.claude-3-sonnet-20240229-v1:0",
|
| 790 |
+
"tags": "LLM,CHAT,200k",
|
| 791 |
+
"max_tokens": 200 * 1024,
|
| 792 |
+
"model_type": LLMType.CHAT.value
|
| 793 |
+
}, {
|
| 794 |
+
"fid": factory_infos[16]["name"],
|
| 795 |
+
"llm_name": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
| 796 |
+
"tags": "LLM,CHAT,200k",
|
| 797 |
+
"max_tokens": 200 * 1024,
|
| 798 |
+
"model_type": LLMType.CHAT.value
|
| 799 |
+
}, {
|
| 800 |
+
"fid": factory_infos[16]["name"],
|
| 801 |
+
"llm_name": "anthropic.claude-3-haiku-20240307-v1:0",
|
| 802 |
+
"tags": "LLM,CHAT,200k",
|
| 803 |
+
"max_tokens": 200 * 1024,
|
| 804 |
+
"model_type": LLMType.CHAT.value
|
| 805 |
+
}, {
|
| 806 |
+
"fid": factory_infos[16]["name"],
|
| 807 |
+
"llm_name": "anthropic.claude-3-opus-20240229-v1:0",
|
| 808 |
+
"tags": "LLM,CHAT,200k",
|
| 809 |
+
"max_tokens": 200 * 1024,
|
| 810 |
+
"model_type": LLMType.CHAT.value
|
| 811 |
+
}, {
|
| 812 |
+
"fid": factory_infos[16]["name"],
|
| 813 |
+
"llm_name": "anthropic.claude-instant-v1",
|
| 814 |
+
"tags": "LLM,CHAT,100k",
|
| 815 |
+
"max_tokens": 100 * 1024,
|
| 816 |
+
"model_type": LLMType.CHAT.value
|
| 817 |
+
}, {
|
| 818 |
+
"fid": factory_infos[16]["name"],
|
| 819 |
+
"llm_name": "amazon.titan-text-express-v1",
|
| 820 |
+
"tags": "LLM,CHAT,8k",
|
| 821 |
+
"max_tokens": 8192,
|
| 822 |
+
"model_type": LLMType.CHAT.value
|
| 823 |
+
}, {
|
| 824 |
+
"fid": factory_infos[16]["name"],
|
| 825 |
+
"llm_name": "amazon.titan-text-premier-v1:0",
|
| 826 |
+
"tags": "LLM,CHAT,32k",
|
| 827 |
+
"max_tokens": 32 * 1024,
|
| 828 |
+
"model_type": LLMType.CHAT.value
|
| 829 |
+
}, {
|
| 830 |
+
"fid": factory_infos[16]["name"],
|
| 831 |
+
"llm_name": "amazon.titan-text-lite-v1",
|
| 832 |
+
"tags": "LLM,CHAT,4k",
|
| 833 |
+
"max_tokens": 4096,
|
| 834 |
+
"model_type": LLMType.CHAT.value
|
| 835 |
+
}, {
|
| 836 |
+
"fid": factory_infos[16]["name"],
|
| 837 |
+
"llm_name": "meta.llama2-13b-chat-v1",
|
| 838 |
+
"tags": "LLM,CHAT,4k",
|
| 839 |
+
"max_tokens": 4096,
|
| 840 |
+
"model_type": LLMType.CHAT.value
|
| 841 |
+
}, {
|
| 842 |
+
"fid": factory_infos[16]["name"],
|
| 843 |
+
"llm_name": "meta.llama2-70b-chat-v1",
|
| 844 |
+
"tags": "LLM,CHAT,4k",
|
| 845 |
+
"max_tokens": 4096,
|
| 846 |
+
"model_type": LLMType.CHAT.value
|
| 847 |
+
}, {
|
| 848 |
+
"fid": factory_infos[16]["name"],
|
| 849 |
+
"llm_name": "meta.llama3-8b-instruct-v1:0",
|
| 850 |
+
"tags": "LLM,CHAT,8k",
|
| 851 |
+
"max_tokens": 8192,
|
| 852 |
+
"model_type": LLMType.CHAT.value
|
| 853 |
+
}, {
|
| 854 |
+
"fid": factory_infos[16]["name"],
|
| 855 |
+
"llm_name": "meta.llama3-70b-instruct-v1:0",
|
| 856 |
+
"tags": "LLM,CHAT,8k",
|
| 857 |
+
"max_tokens": 8192,
|
| 858 |
+
"model_type": LLMType.CHAT.value
|
| 859 |
+
}, {
|
| 860 |
+
"fid": factory_infos[16]["name"],
|
| 861 |
+
"llm_name": "mistral.mistral-7b-instruct-v0:2",
|
| 862 |
+
"tags": "LLM,CHAT,8k",
|
| 863 |
+
"max_tokens": 8192,
|
| 864 |
+
"model_type": LLMType.CHAT.value
|
| 865 |
+
}, {
|
| 866 |
+
"fid": factory_infos[16]["name"],
|
| 867 |
+
"llm_name": "mistral.mixtral-8x7b-instruct-v0:1",
|
| 868 |
+
"tags": "LLM,CHAT,4k",
|
| 869 |
+
"max_tokens": 4096,
|
| 870 |
+
"model_type": LLMType.CHAT.value
|
| 871 |
+
}, {
|
| 872 |
+
"fid": factory_infos[16]["name"],
|
| 873 |
+
"llm_name": "mistral.mistral-large-2402-v1:0",
|
| 874 |
+
"tags": "LLM,CHAT,8k",
|
| 875 |
+
"max_tokens": 8192,
|
| 876 |
+
"model_type": LLMType.CHAT.value
|
| 877 |
+
}, {
|
| 878 |
+
"fid": factory_infos[16]["name"],
|
| 879 |
+
"llm_name": "mistral.mistral-small-2402-v1:0",
|
| 880 |
+
"tags": "LLM,CHAT,8k",
|
| 881 |
+
"max_tokens": 8192,
|
| 882 |
+
"model_type": LLMType.CHAT.value
|
| 883 |
+
}, {
|
| 884 |
+
"fid": factory_infos[16]["name"],
|
| 885 |
+
"llm_name": "amazon.titan-embed-text-v2:0",
|
| 886 |
+
"tags": "TEXT EMBEDDING",
|
| 887 |
+
"max_tokens": 8192,
|
| 888 |
+
"model_type": LLMType.EMBEDDING.value
|
| 889 |
+
}, {
|
| 890 |
+
"fid": factory_infos[16]["name"],
|
| 891 |
+
"llm_name": "cohere.embed-english-v3",
|
| 892 |
+
"tags": "TEXT EMBEDDING",
|
| 893 |
+
"max_tokens": 2048,
|
| 894 |
+
"model_type": LLMType.EMBEDDING.value
|
| 895 |
+
}, {
|
| 896 |
+
"fid": factory_infos[16]["name"],
|
| 897 |
+
"llm_name": "cohere.embed-multilingual-v3",
|
| 898 |
+
"tags": "TEXT EMBEDDING",
|
| 899 |
+
"max_tokens": 2048,
|
| 900 |
+
"model_type": LLMType.EMBEDDING.value
|
| 901 |
+
},
|
| 902 |
]
|
| 903 |
for info in factory_infos:
|
| 904 |
try:
|
rag/llm/__init__.py
CHANGED
|
@@ -31,7 +31,8 @@ EmbeddingModel = {
|
|
| 31 |
"BaiChuan": BaiChuanEmbed,
|
| 32 |
"Jina": JinaEmbed,
|
| 33 |
"BAAI": DefaultEmbedding,
|
| 34 |
-
"Mistral": MistralEmbed
|
|
|
|
| 35 |
}
|
| 36 |
|
| 37 |
|
|
@@ -58,7 +59,8 @@ ChatModel = {
|
|
| 58 |
"VolcEngine": VolcEngineChat,
|
| 59 |
"BaiChuan": BaiChuanChat,
|
| 60 |
"MiniMax": MiniMaxChat,
|
| 61 |
-
"Mistral": MistralChat
|
|
|
|
| 62 |
}
|
| 63 |
|
| 64 |
|
|
|
|
| 31 |
"BaiChuan": BaiChuanEmbed,
|
| 32 |
"Jina": JinaEmbed,
|
| 33 |
"BAAI": DefaultEmbedding,
|
| 34 |
+
"Mistral": MistralEmbed,
|
| 35 |
+
"Bedrock": BedrockEmbed
|
| 36 |
}
|
| 37 |
|
| 38 |
|
|
|
|
| 59 |
"VolcEngine": VolcEngineChat,
|
| 60 |
"BaiChuan": BaiChuanChat,
|
| 61 |
"MiniMax": MiniMaxChat,
|
| 62 |
+
"Mistral": MistralChat,
|
| 63 |
+
"Bedrock": BedrockChat
|
| 64 |
}
|
| 65 |
|
| 66 |
|
rag/llm/chat_model.py
CHANGED
|
@@ -533,3 +533,90 @@ class MistralChat(Base):
|
|
| 533 |
yield ans + "\n**ERROR**: " + str(e)
|
| 534 |
|
| 535 |
yield total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
yield ans + "\n**ERROR**: " + str(e)
|
| 534 |
|
| 535 |
yield total_tokens
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class BedrockChat(Base):
|
| 539 |
+
|
| 540 |
+
def __init__(self, key, model_name, **kwargs):
|
| 541 |
+
import boto3
|
| 542 |
+
from botocore.exceptions import ClientError
|
| 543 |
+
self.bedrock_ak = eval(key).get('bedrock_ak', '')
|
| 544 |
+
self.bedrock_sk = eval(key).get('bedrock_sk', '')
|
| 545 |
+
self.bedrock_region = eval(key).get('bedrock_region', '')
|
| 546 |
+
self.model_name = model_name
|
| 547 |
+
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
| 548 |
+
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
| 549 |
+
|
| 550 |
+
def chat(self, system, history, gen_conf):
|
| 551 |
+
if system:
|
| 552 |
+
history.insert(0, {"role": "system", "content": system})
|
| 553 |
+
for k in list(gen_conf.keys()):
|
| 554 |
+
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 555 |
+
del gen_conf[k]
|
| 556 |
+
if "max_tokens" in gen_conf:
|
| 557 |
+
gen_conf["maxTokens"] = gen_conf["max_tokens"]
|
| 558 |
+
_ = gen_conf.pop("max_tokens")
|
| 559 |
+
if "top_p" in gen_conf:
|
| 560 |
+
gen_conf["topP"] = gen_conf["top_p"]
|
| 561 |
+
_ = gen_conf.pop("top_p")
|
| 562 |
+
|
| 563 |
+
try:
|
| 564 |
+
# Send the message to the model, using a basic inference configuration.
|
| 565 |
+
response = self.client.converse(
|
| 566 |
+
modelId=self.model_name,
|
| 567 |
+
messages=history,
|
| 568 |
+
inferenceConfig=gen_conf
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Extract and print the response text.
|
| 572 |
+
ans = response["output"]["message"]["content"][0]["text"]
|
| 573 |
+
return ans, num_tokens_from_string(ans)
|
| 574 |
+
|
| 575 |
+
except (ClientError, Exception) as e:
|
| 576 |
+
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
|
| 577 |
+
|
| 578 |
+
def chat_streamly(self, system, history, gen_conf):
|
| 579 |
+
if system:
|
| 580 |
+
history.insert(0, {"role": "system", "content": system})
|
| 581 |
+
for k in list(gen_conf.keys()):
|
| 582 |
+
if k not in ["temperature", "top_p", "max_tokens"]:
|
| 583 |
+
del gen_conf[k]
|
| 584 |
+
if "max_tokens" in gen_conf:
|
| 585 |
+
gen_conf["maxTokens"] = gen_conf["max_tokens"]
|
| 586 |
+
_ = gen_conf.pop("max_tokens")
|
| 587 |
+
if "top_p" in gen_conf:
|
| 588 |
+
gen_conf["topP"] = gen_conf["top_p"]
|
| 589 |
+
_ = gen_conf.pop("top_p")
|
| 590 |
+
|
| 591 |
+
if self.model_name.split('.')[0] == 'ai21':
|
| 592 |
+
try:
|
| 593 |
+
response = self.client.converse(
|
| 594 |
+
modelId=self.model_name,
|
| 595 |
+
messages=history,
|
| 596 |
+
inferenceConfig=gen_conf
|
| 597 |
+
)
|
| 598 |
+
ans = response["output"]["message"]["content"][0]["text"]
|
| 599 |
+
return ans, num_tokens_from_string(ans)
|
| 600 |
+
|
| 601 |
+
except (ClientError, Exception) as e:
|
| 602 |
+
return f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}", 0
|
| 603 |
+
|
| 604 |
+
ans = ""
|
| 605 |
+
try:
|
| 606 |
+
# Send the message to the model, using a basic inference configuration.
|
| 607 |
+
streaming_response = self.client.converse_stream(
|
| 608 |
+
modelId=self.model_name,
|
| 609 |
+
messages=history,
|
| 610 |
+
inferenceConfig=gen_conf
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# Extract and print the streamed response text in real-time.
|
| 614 |
+
for resp in streaming_response["stream"]:
|
| 615 |
+
if "contentBlockDelta" in resp:
|
| 616 |
+
ans += resp["contentBlockDelta"]["delta"]["text"]
|
| 617 |
+
yield ans
|
| 618 |
+
|
| 619 |
+
except (ClientError, Exception) as e:
|
| 620 |
+
yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
|
| 621 |
+
|
| 622 |
+
yield num_tokens_from_string(ans)
|
rag/llm/embedding_model.py
CHANGED
|
@@ -374,3 +374,48 @@ class MistralEmbed(Base):
|
|
| 374 |
res = self.client.embeddings(input=[truncate(text, 8196)],
|
| 375 |
model=self.model_name)
|
| 376 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
res = self.client.embeddings(input=[truncate(text, 8196)],
|
| 375 |
model=self.model_name)
|
| 376 |
return np.array(res.data[0].embedding), res.usage.total_tokens
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class BedrockEmbed(Base):
|
| 380 |
+
def __init__(self, key, model_name,
|
| 381 |
+
**kwargs):
|
| 382 |
+
import boto3
|
| 383 |
+
self.bedrock_ak = eval(key).get('bedrock_ak', '')
|
| 384 |
+
self.bedrock_sk = eval(key).get('bedrock_sk', '')
|
| 385 |
+
self.bedrock_region = eval(key).get('bedrock_region', '')
|
| 386 |
+
self.model_name = model_name
|
| 387 |
+
self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
|
| 388 |
+
aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
| 389 |
+
|
| 390 |
+
def encode(self, texts: list, batch_size=32):
|
| 391 |
+
texts = [truncate(t, 8196) for t in texts]
|
| 392 |
+
embeddings = []
|
| 393 |
+
token_count = 0
|
| 394 |
+
for text in texts:
|
| 395 |
+
if self.model_name.split('.')[0] == 'amazon':
|
| 396 |
+
body = {"inputText": text}
|
| 397 |
+
elif self.model_name.split('.')[0] == 'cohere':
|
| 398 |
+
body = {"texts": [text], "input_type": 'search_document'}
|
| 399 |
+
|
| 400 |
+
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
| 401 |
+
model_response = json.loads(response["body"].read())
|
| 402 |
+
embeddings.extend([model_response["embedding"]])
|
| 403 |
+
token_count += num_tokens_from_string(text)
|
| 404 |
+
|
| 405 |
+
return np.array(embeddings), token_count
|
| 406 |
+
|
| 407 |
+
def encode_queries(self, text):
|
| 408 |
+
|
| 409 |
+
embeddings = []
|
| 410 |
+
token_count = num_tokens_from_string(text)
|
| 411 |
+
if self.model_name.split('.')[0] == 'amazon':
|
| 412 |
+
body = {"inputText": truncate(text, 8196)}
|
| 413 |
+
elif self.model_name.split('.')[0] == 'cohere':
|
| 414 |
+
body = {"texts": [truncate(text, 8196)], "input_type": 'search_query'}
|
| 415 |
+
|
| 416 |
+
response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
|
| 417 |
+
model_response = json.loads(response["body"].read())
|
| 418 |
+
embeddings.extend([model_response["embedding"]])
|
| 419 |
+
|
| 420 |
+
return np.array(embeddings), token_count
|
| 421 |
+
|
requirements.txt
CHANGED
|
@@ -144,4 +144,6 @@ cn2an==0.5.22
|
|
| 144 |
roman-numbers==1.0.2
|
| 145 |
word2number==1.1
|
| 146 |
markdown==3.6
|
|
|
|
|
|
|
| 147 |
duckduckgo_search==6.1.9
|
|
|
|
| 144 |
roman-numbers==1.0.2
|
| 145 |
word2number==1.1
|
| 146 |
markdown==3.6
|
| 147 |
+
mistralai==0.4.2
|
| 148 |
+
boto3==1.34.140
|
| 149 |
duckduckgo_search==6.1.9
|
requirements_arm.txt
CHANGED
|
@@ -145,4 +145,6 @@ cn2an==0.5.22
|
|
| 145 |
roman-numbers==1.0.2
|
| 146 |
word2number==1.1
|
| 147 |
markdown==3.6
|
|
|
|
|
|
|
| 148 |
duckduckgo_search==6.1.9
|
|
|
|
| 145 |
roman-numbers==1.0.2
|
| 146 |
word2number==1.1
|
| 147 |
markdown==3.6
|
| 148 |
+
mistralai==0.4.2
|
| 149 |
+
boto3==1.34.140
|
| 150 |
duckduckgo_search==6.1.9
|
requirements_dev.txt
CHANGED
|
@@ -130,4 +130,6 @@ cn2an==0.5.22
|
|
| 130 |
roman-numbers==1.0.2
|
| 131 |
word2number==1.1
|
| 132 |
markdown==3.6
|
|
|
|
|
|
|
| 133 |
duckduckgo_search==6.1.9
|
|
|
|
| 130 |
roman-numbers==1.0.2
|
| 131 |
word2number==1.1
|
| 132 |
markdown==3.6
|
| 133 |
+
mistralai==0.4.2
|
| 134 |
+
boto3==1.34.140
|
| 135 |
duckduckgo_search==6.1.9
|