|
"""LLM provider factory for the web UI, plus DeepSeek R1 wrappers that
expose the model's reasoning content alongside the final answer."""

import os
from typing import Any, Optional

from openai import AsyncOpenAI, OpenAI

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ibm import ChatWatsonx
from langchain_mistralai import ChatMistralAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from src.utils import config

|
class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """ChatOpenAI wrapper for DeepSeek R1 (``deepseek-reasoner``).

    The DeepSeek API returns the model's chain of thought in a separate
    ``reasoning_content`` field that the stock LangChain wrapper discards,
    so ``invoke``/``ainvoke`` call the OpenAI-compatible endpoint directly
    and surface that field on the returned ``AIMessage``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Raw OpenAI-compatible clients; only used by invoke/ainvoke below.
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )
        self.async_client = AsyncOpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )

    def _build_message_history(self, input: LanguageModelInput) -> list[dict]:
        """Normalize any LanguageModelInput into OpenAI-style role dicts."""
        message_history = []
        for message in self._convert_input(input).to_messages():
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                role = "user"
            message_history.append({"role": role, "content": message.content})
        return message_history

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        # config/stop are accepted for interface compatibility but are not
        # forwarded to the raw client call.
        response = await self.async_client.chat.completions.create(
            model=self.model_name,
            messages=self._build_message_history(input),
        )
        message = response.choices[0].message
        # ``reasoning_content`` is a DeepSeek extension to the OpenAI schema.
        return AIMessage(
            content=message.content,
            reasoning_content=message.reasoning_content,
        )

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._build_message_history(input),
        )
        message = response.choices[0].message
        return AIMessage(
            content=message.content,
            reasoning_content=message.reasoning_content,
        )
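
# Usage sketch (the endpoint and key below are assumptions, not part of
# this module):
#
#   llm = DeepSeekR1ChatOpenAI(
#       model="deepseek-reasoner",
#       base_url="https://api.deepseek.com",
#       api_key="sk-...",
#   )
#   reply = llm.invoke("What is 17 * 24?")
#   print(reply.reasoning_content)  # the model's chain of thought
#   print(reply.content)            # the final answer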
|
|
|
|
|
class DeepSeekR1ChatOllama(ChatOllama):
    """ChatOllama wrapper for local DeepSeek R1 models.

    Local R1 checkpoints embed the chain of thought inline as
    ``<think>...</think>`` tags, so the raw content is split into
    ``reasoning_content`` and the final answer.
    """

    @staticmethod
    def _split_reasoning(org_content: str) -> tuple[str, str]:
        """Split raw model output into ``(reasoning_content, content)``."""
        if "</think>" in org_content:
            reasoning_content, _, content = org_content.partition("</think>")
            reasoning_content = reasoning_content.replace("<think>", "")
        else:
            # No think tags at all: treat everything as the answer.
            reasoning_content, content = "", org_content
        # Some prompts make the model prefix its answer with this marker.
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return reasoning_content, content

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = await super().ainvoke(input, config, stop=stop, **kwargs)
        reasoning_content, content = self._split_reasoning(org_ai_message.content)
        return AIMessage(content=content, reasoning_content=reasoning_content)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = super().invoke(input, config, stop=stop, **kwargs)
        reasoning_content, content = self._split_reasoning(org_ai_message.content)
        return AIMessage(content=content, reasoning_content=reasoning_content)
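
# Sketch of the tag splitting above, with an assumed raw completion:
#
#   raw = '<think>User wants JSON.</think>**JSON Response:** {"ok": true}'
#   reasoning, answer = DeepSeekR1ChatOllama._split_reasoning(raw)
#   # reasoning == "User wants JSON."
#   # answer == ' {"ok": true}'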
|
|
|
|
|
def get_llm_model(provider: str, **kwargs):
    """Instantiate a chat model for the given provider.

    :param provider: provider id, e.g. "openai", "anthropic", "ollama"
    :param kwargs: model_name, temperature, base_url, api_key, etc.
    :return: a LangChain chat model instance
    """
    if provider not in ["ollama", "bedrock"]:
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key:
            provider_display = config.PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
            error_msg = (
                f"💥 {provider_display} API key not found! 🔑 Please set the "
                f"`{env_var}` environment variable or provide it in the UI."
            )
            raise ValueError(error_msg)
        kwargs["api_key"] = api_key

    if provider == "anthropic":
        base_url = kwargs.get("base_url") or "https://api.anthropic.com"
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "mistral":
        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "openai":
        base_url = kwargs.get("base_url") or os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "grok":
        base_url = kwargs.get("base_url") or os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
        return ChatOpenAI(
            model=kwargs.get("model_name", "grok-3"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "deepseek":
        base_url = kwargs.get("base_url") or os.getenv("DEEPSEEK_ENDPOINT", "")
        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-reasoner"),
                temperature=kwargs.get("temperature", 0.0),
                base_url=base_url,
                api_key=api_key,
            )
        else:
            return ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-chat"),
                temperature=kwargs.get("temperature", 0.0),
                base_url=base_url,
                api_key=api_key,
            )
    elif provider == "google":
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=kwargs.get("temperature", 0.0),
            api_key=api_key,
        )
    elif provider == "ollama":
        base_url = kwargs.get("base_url") or os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
            return DeepSeekR1ChatOllama(
                model=kwargs.get("model_name", "deepseek-r1:14b"),
                temperature=kwargs.get("temperature", 0.0),
                num_ctx=kwargs.get("num_ctx", 32000),
                base_url=base_url,
            )
        else:
            return ChatOllama(
                model=kwargs.get("model_name", "qwen2.5:7b"),
                temperature=kwargs.get("temperature", 0.0),
                num_ctx=kwargs.get("num_ctx", 32000),
                num_predict=kwargs.get("num_predict", 1024),
                base_url=base_url,
            )
    elif provider == "azure_openai":
        base_url = kwargs.get("base_url") or os.getenv("AZURE_OPENAI_ENDPOINT", "")
        api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=kwargs.get("temperature", 0.0),
            api_version=api_version,
            azure_endpoint=base_url,
            api_key=api_key,
        )
    elif provider == "alibaba":
        base_url = kwargs.get("base_url") or os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
        return ChatOpenAI(
            model=kwargs.get("model_name", "qwen-plus"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "ibm":
        parameters = {
            "temperature": kwargs.get("temperature", 0.0),
            "max_tokens": kwargs.get("num_ctx", 32000),
        }
        base_url = kwargs.get("base_url") or os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
        return ChatWatsonx(
            model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
            url=base_url,
            project_id=os.getenv("IBM_PROJECT_ID"),
            apikey=api_key,
            params=parameters,
        )
    elif provider == "moonshot":
        return ChatOpenAI(
            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=os.getenv("MOONSHOT_ENDPOINT"),
            api_key=api_key,
        )
    elif provider == "unbound":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o-mini"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
            api_key=api_key,
        )
    elif provider == "siliconflow":
        base_url = kwargs.get("base_url") or os.getenv("SiliconFLOW_ENDPOINT", "")
        return ChatOpenAI(
            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "modelscope":
        base_url = kwargs.get("base_url") or os.getenv("MODELSCOPE_ENDPOINT", "")
        return ChatOpenAI(
            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")
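

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming OPENAI_API_KEY is set in the
    # environment); swap in whichever provider/model you have access to.
    llm = get_llm_model("openai", model_name="gpt-4o", temperature=0.0)
    print(llm.invoke("Say hello in one word.").content)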
|
|