Spaces:
Running
Running
| from openai import OpenAI | |
| import pdb | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.globals import get_llm_cache | |
| from langchain_core.language_models.base import ( | |
| BaseLanguageModel, | |
| LangSmithParams, | |
| LanguageModelInput, | |
| ) | |
| from langchain_core.load import dumpd, dumps | |
| from langchain_core.messages import ( | |
| AIMessage, | |
| SystemMessage, | |
| AnyMessage, | |
| BaseMessage, | |
| BaseMessageChunk, | |
| HumanMessage, | |
| convert_to_messages, | |
| message_chunk_to_message, | |
| ) | |
| from langchain_core.outputs import ( | |
| ChatGeneration, | |
| ChatGenerationChunk, | |
| ChatResult, | |
| LLMResult, | |
| RunInfo, | |
| ) | |
| from langchain_ollama import ChatOllama | |
| from langchain_core.output_parsers.base import OutputParserLike | |
| from langchain_core.runnables import Runnable, RunnableConfig | |
| from langchain_core.tools import BaseTool | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Literal, | |
| Optional, | |
| Union, | |
| cast, List, | |
| ) | |
class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """``ChatOpenAI`` subclass that also returns the model's reasoning text.

    ``invoke`` / ``ainvoke`` bypass the LangChain generation pipeline and call
    the OpenAI-compatible chat-completions endpoint directly, so that the
    ``reasoning_content`` field of the response message can be copied onto
    the returned ``AIMessage`` next to the regular ``content``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialise the parent model, then install a raw ``OpenAI`` client.

        Args:
            *args: Forwarded unchanged to ``ChatOpenAI``.
            **kwargs: Forwarded unchanged to ``ChatOpenAI``. The ``base_url``
                and ``api_key`` entries are additionally used to build the
                raw client stored on ``self.client``.
        """
        super().__init__(*args, **kwargs)
        api_key = kwargs.get("api_key")
        # The raw OpenAI client needs the plain string; unwrap secret
        # wrappers (e.g. pydantic ``SecretStr``) if one was supplied.
        if hasattr(api_key, "get_secret_value"):
            api_key = api_key.get_secret_value()
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=api_key,
        )

    @staticmethod
    def _to_openai_messages(model_input: LanguageModelInput) -> list[dict[str, Any]]:
        """Convert LangChain input into chat-completions message dicts."""
        # A bare string would otherwise be iterated character by character.
        if isinstance(model_input, str):
            model_input = [HumanMessage(content=model_input)]

        history: list[dict[str, Any]] = []
        for message in convert_to_messages(model_input):
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                # Every other message type is sent with the "user" role.
                role = "user"
            history.append({"role": role, "content": message.content})
        return history

    def _complete(
        self,
        model_input: LanguageModelInput,
        stop: Optional[list[str]] = None,
    ) -> AIMessage:
        """Run one blocking chat-completion call and wrap the first choice."""
        request: dict[str, Any] = {
            "model": self.model_name,
            "messages": self._to_openai_messages(model_input),
        }
        # Only send ``stop`` when the caller asked for it, so the default
        # request is identical to one without stop sequences.
        if stop is not None:
            request["stop"] = stop

        response = self.client.chat.completions.create(**request)
        message = response.choices[0].message
        # ``reasoning_content`` is not guaranteed to be present on every
        # backend; fall back to an empty string instead of raising.
        reasoning_content = getattr(message, "reasoning_content", None) or ""
        # ``content`` may come back as None, which AIMessage does not accept.
        content = message.content or ""
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Asynchronously query the model.

        The underlying client is synchronous, so the request is executed in a
        worker thread to avoid blocking the running event loop.

        Args:
            input: A string, prompt value, or sequence of messages.
            config: Accepted for interface compatibility; not used.
            stop: Optional stop sequences forwarded to the API.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            An ``AIMessage`` carrying ``content`` and ``reasoning_content``.
        """
        import asyncio

        return await asyncio.to_thread(self._complete, input, stop)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Synchronously query the model.

        Args:
            input: A string, prompt value, or sequence of messages.
            config: Accepted for interface compatibility; not used.
            stop: Optional stop sequences forwarded to the API.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            An ``AIMessage`` carrying ``content`` and ``reasoning_content``.
        """
        return self._complete(input, stop)
class DeepSeekR1ChatOllama(ChatOllama):
    """``ChatOllama`` subclass that separates ``<think>`` reasoning from the answer.

    The parent model is invoked normally; the returned text is then split at
    the ``</think>`` tag into ``reasoning_content`` (the part before the tag,
    with ``<think>`` removed) and ``content`` (the part after it).
    """

    @staticmethod
    def _split_reasoning(message: BaseMessage) -> AIMessage:
        """Build an ``AIMessage`` with reasoning and answer separated.

        If the text has no ``</think>`` tag, the whole text is treated as the
        answer and ``reasoning_content`` is empty. If the answer contains the
        marker ``**JSON Response:**``, only the text after the last marker is
        kept.
        """
        raw = message.content
        # Non-string content (e.g. a list of content parts) cannot be split;
        # pass it through untouched.
        if not isinstance(raw, str):
            return AIMessage(content=raw, reasoning_content="")

        head, closing_tag, tail = raw.partition("</think>")
        if closing_tag:
            reasoning_content = head.replace("<think>", "")
            # Everything after the first closing tag is the answer.
            content = tail
        else:
            # No reasoning section was emitted; do not raise, just return
            # the text as the answer.
            reasoning_content = ""
            content = raw

        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Asynchronously invoke the parent model and split off the reasoning.

        Args:
            input: Model input, forwarded to ``ChatOllama.ainvoke``.
            config: Runnable configuration, forwarded to the parent call.
            stop: Optional stop sequences, forwarded to the parent call.
            **kwargs: Extra keyword arguments, forwarded to the parent call.

        Returns:
            An ``AIMessage`` carrying ``content`` and ``reasoning_content``.
        """
        org_ai_message = await super().ainvoke(input, config, stop=stop, **kwargs)
        return self._split_reasoning(org_ai_message)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        """Invoke the parent model and split off the reasoning.

        Args:
            input: Model input, forwarded to ``ChatOllama.invoke``.
            config: Runnable configuration, forwarded to the parent call.
            stop: Optional stop sequences, forwarded to the parent call.
            **kwargs: Extra keyword arguments, forwarded to the parent call.

        Returns:
            An ``AIMessage`` carrying ``content`` and ``reasoning_content``.
        """
        org_ai_message = super().invoke(input, config, stop=stop, **kwargs)
        return self._split_reasoning(org_ai_message)