Spaces:
Sleeping
Sleeping
[email protected]
commited on
Commit
·
ea077e1
1
Parent(s):
89033ee
Implement model integration strategy and selector for multiple AI models
Browse files- model/ModelIntegrations.py +37 -0
- model/ModelStrategy.py +6 -0
- model/selector.py +43 -0
- pages/chatbot.py +3 -0
- rag.py +2 -5
- requirements.txt +3 -0
model/ModelIntegrations.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .ModelStrategy import ModelStrategy
|
| 2 |
+
|
| 3 |
+
from langchain_community.chat_models import ChatOpenAI
|
| 4 |
+
from langchain_mistralai.chat_models import ChatMistralAI
|
| 5 |
+
from langchain_anthropic import ChatAnthropic
|
| 6 |
+
from langchain_ollama import ChatOllama
|
| 7 |
+
|
| 8 |
+
class MistralModel(ModelStrategy):
|
| 9 |
+
def get_model(self, model_name):
|
| 10 |
+
return ChatMistralAI(model=model_name)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class OpenAIModel(ModelStrategy):
|
| 14 |
+
def get_model(self, model_name):
|
| 15 |
+
return ChatOpenAI(model=model_name)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AnthropicModel(ModelStrategy):
|
| 19 |
+
def get_model(self, model_name):
|
| 20 |
+
return ChatAnthropic(model=model_name)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OllamaModel(ModelStrategy):
|
| 24 |
+
def get_model(self, model_name):
|
| 25 |
+
return ChatOllama(model=model_name)
|
| 26 |
+
|
| 27 |
+
class ModelManager():
|
| 28 |
+
def __init__(self):
|
| 29 |
+
self.models = {
|
| 30 |
+
"mistral": MistralModel(),
|
| 31 |
+
"openai": OpenAIModel(),
|
| 32 |
+
"anthropic": AnthropicModel(),
|
| 33 |
+
"ollama": OllamaModel()
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
def get_model(self, provider, model_name):
|
| 37 |
+
return self.models[provider].get_model(model_name)
|
model/ModelStrategy.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
class ModelStrategy(ABC):
|
| 4 |
+
@abstractmethod
|
| 5 |
+
def get_model(self, model_name):
|
| 6 |
+
pass
|
model/selector.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from .ModelIntegrations import ModelManager
|
| 3 |
+
|
| 4 |
+
def ModelSelector():
|
| 5 |
+
# Dictionnaire des modèles par fournisseur
|
| 6 |
+
model_providers = {
|
| 7 |
+
"Mistral": {
|
| 8 |
+
"mistral-large-latest": "mistral.mistral-large-latest",
|
| 9 |
+
"open-mixtral-8x7b": "mistral.open-mixtral-8x7b",
|
| 10 |
+
},
|
| 11 |
+
"OpenAI": {
|
| 12 |
+
"gpt-4o": "openai.gpt-4o",
|
| 13 |
+
},
|
| 14 |
+
"Anthropic": {
|
| 15 |
+
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620",
|
| 16 |
+
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229",
|
| 17 |
+
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229",
|
| 18 |
+
},
|
| 19 |
+
# "Ollama": {
|
| 20 |
+
# "llama3": "ollama.llama3"
|
| 21 |
+
# }
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
# Créer une liste avec les noms de modèle, groupés par fournisseur (fournisseur - modèle)
|
| 25 |
+
model_options = []
|
| 26 |
+
model_mapping = {}
|
| 27 |
+
|
| 28 |
+
for provider, models in model_providers.items():
|
| 29 |
+
for model_name, model_instance in models.items():
|
| 30 |
+
option_name = f"{provider} - {model_name}"
|
| 31 |
+
model_options.append(option_name)
|
| 32 |
+
model_mapping[option_name] = model_instance
|
| 33 |
+
|
| 34 |
+
# Sélection d'un modèle via un seul sélecteur
|
| 35 |
+
selected_model_option = st.selectbox("Choisissez votre modèle", options=model_options)
|
| 36 |
+
|
| 37 |
+
# Afficher le modèle sélectionné
|
| 38 |
+
st.write(f"Current model: {model_mapping[selected_model_option]}")
|
| 39 |
+
|
| 40 |
+
if(st.session_state["assistant"]):
|
| 41 |
+
splitter = model_mapping[selected_model_option].split(".")
|
| 42 |
+
st.session_state["assistant"].setModel(ModelManager().get_model(splitter[0], splitter[1]))
|
| 43 |
+
|
pages/chatbot.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit_chat import message
|
|
|
|
| 3 |
|
| 4 |
def display_messages():
|
| 5 |
for i, (msg, is_user) in enumerate(st.session_state["messages"]):
|
|
@@ -22,6 +23,8 @@ def process_input():
|
|
| 22 |
def page():
|
| 23 |
st.subheader("Posez vos questions")
|
| 24 |
|
|
|
|
|
|
|
| 25 |
if "assistant" not in st.session_state:
|
| 26 |
st.text("Assistant non initialisé")
|
| 27 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from streamlit_chat import message
|
| 3 |
+
from model import selector
|
| 4 |
|
| 5 |
def display_messages():
|
| 6 |
for i, (msg, is_user) in enumerate(st.session_state["messages"]):
|
|
|
|
| 23 |
def page():
|
| 24 |
st.subheader("Posez vos questions")
|
| 25 |
|
| 26 |
+
selector.ModelSelector()
|
| 27 |
+
|
| 28 |
if "assistant" not in st.session_state:
|
| 29 |
st.text("Assistant non initialisé")
|
| 30 |
|
rag.py
CHANGED
|
@@ -19,7 +19,6 @@ from prompt_template import base_template
|
|
| 19 |
# load .env in local dev
|
| 20 |
load_dotenv()
|
| 21 |
env_api_key = os.environ.get("MISTRAL_API_KEY")
|
| 22 |
-
llm_model = "open-mixtral-8x7b"
|
| 23 |
|
| 24 |
class Rag:
|
| 25 |
document_vector_store = None
|
|
@@ -28,7 +27,7 @@ class Rag:
|
|
| 28 |
|
| 29 |
def __init__(self, vectore_store=None):
|
| 30 |
|
| 31 |
-
self.model = ChatMistralAI(model=llm_model)
|
| 32 |
self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
|
| 33 |
|
| 34 |
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len)
|
|
@@ -73,11 +72,9 @@ class Rag:
|
|
| 73 |
)
|
| 74 |
|
| 75 |
def ask(self, query: str, messages: list):
|
| 76 |
-
|
| 77 |
self.chain = self.prompt | self.model | StrOutputParser()
|
| 78 |
|
| 79 |
-
print("messages ", messages)
|
| 80 |
-
|
| 81 |
# Retrieve the context document
|
| 82 |
if self.retriever is None:
|
| 83 |
documentContext = ''
|
|
|
|
| 19 |
# load .env in local dev
|
| 20 |
load_dotenv()
|
| 21 |
env_api_key = os.environ.get("MISTRAL_API_KEY")
|
|
|
|
| 22 |
|
| 23 |
class Rag:
|
| 24 |
document_vector_store = None
|
|
|
|
| 27 |
|
| 28 |
def __init__(self, vectore_store=None):
|
| 29 |
|
| 30 |
+
# self.model = ChatMistralAI(model=llm_model)
|
| 31 |
self.embedding = MistralAIEmbeddings(model="mistral-embed", mistral_api_key=env_api_key)
|
| 32 |
|
| 33 |
self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100, length_function=len)
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
def ask(self, query: str, messages: list):
|
| 75 |
+
print(self.model)
|
| 76 |
self.chain = self.prompt | self.model | StrOutputParser()
|
| 77 |
|
|
|
|
|
|
|
| 78 |
# Retrieve the context document
|
| 79 |
if self.retriever is None:
|
| 80 |
documentContext = ''
|
requirements.txt
CHANGED
|
@@ -17,3 +17,6 @@ langchain-openai
|
|
| 17 |
langchain-community
|
| 18 |
langchain-pinecone
|
| 19 |
langchain_mistralai
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
langchain-community
|
| 18 |
langchain-pinecone
|
| 19 |
langchain_mistralai
|
| 20 |
+
langchain_anthropic
|
| 21 |
+
langchain_ollama
|
| 22 |
+
pyyaml
|