import os
import sysconfig
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import geocoder
from openai import AsyncOpenAI, OpenAI
from pydantic import Field

from omagent_core.models.llms.base import BaseLLM
from omagent_core.models.llms.schemas import Content, Message
from omagent_core.utils.registry import registry

BASIC_SYS_PROMPT = """You are an intelligent agent that can help in many regions.
Following is some basic information about your working environment. Please do your best to answer questions based on it when needed.
Be confident about this information and do not let others sense that it was preset.
Be concise.
---BASIC INFORMATION---
Current Datetime: {}
Region: {}
Operating System: {}"""
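
# Example of the prompt after formatting by _generate_default_sys_prompt()
# (values illustrative only):
#   Current Datetime: 2024-01-01 12:00:00
#   Region: Berlin,DE
#   Operating System: Ubuntu 22.04 LTS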

@registry.register_llm()
class OpenaiGPTLLM(BaseLLM):
    model_id: str = Field(
        default=os.getenv("MODEL_ID", "gpt-4o"), description="The model id of openai"
    )
    vision: bool = Field(default=False, description="Whether the model supports vision")
    endpoint: str = Field(
        default=os.getenv("ENDPOINT", "https://api.openai.com/v1"),
        description="The endpoint of the LLM service",
    )
    api_key: str = Field(
        default=os.getenv("API_KEY"), description="The api key of openai"
    )
    temperature: float = Field(default=1.0, description="The temperature of LLM")
    top_p: float = Field(
        default=1.0,
        description="The top p of LLM, controls diversity of responses. Use either temperature or top_p, but not both",
    )
    stream: bool = Field(default=False, description="Whether to stream the response")
    max_tokens: int = Field(default=2048, description="The max tokens of LLM")
    use_default_sys_prompt: bool = Field(
        default=True, description="Whether to use the default system prompt"
    )
    response_format: Optional[Union[dict, str]] = Field(
        default="text", description="The response format of openai"
    )
    n: int = Field(default=1, description="The number of responses to generate")
    frequency_penalty: float = Field(
        default=0, description="The frequency penalty of LLM, -2 to 2"
    )
    logit_bias: Optional[dict] = Field(default=None, description="The logit bias of LLM")
    logprobs: bool = Field(default=False, description="The logprobs of LLM")
    top_logprobs: Optional[int] = Field(
        default=None,
        description="The top logprobs of LLM; logprobs must be set to true if this parameter is used",
    )
    stop: Union[str, List[str], None] = Field(
        default=None,
        description="Stop sequences that will halt text generation; can be a string or a list of strings",
    )
    stream_options: Optional[dict] = Field(
        default=None,
        description="Configuration options for streaming responses when stream=True",
    )
    tools: Optional[List[dict]] = Field(
        default=None,
        description="A list of function tools (max 128) that the model can call, each requiring a type, name and optional description/parameters defined in JSON Schema format.",
    )
    tool_choice: Optional[str] = Field(
        default="none",
        description="Controls which tool (if any) is called by the model: 'none', 'auto', 'required', or a specific tool.",
    )

    class Config:
        """Configuration for this pydantic object."""

        protected_namespaces = ()
        extra = "allow"

    def check_response_format(self) -> None:
        """Normalize response_format into the dict form expected by the API."""
        if isinstance(self.response_format, str):
            if self.response_format == "text":
                self.response_format = {"type": "text"}
            elif self.response_format == "json_object":
                self.response_format = {"type": "json_object"}
            else:
                raise ValueError(f"Invalid response format: {self.response_format}")
        elif isinstance(self.response_format, dict):
            for key, value in self.response_format.items():
                if key not in ["type", "json_schema"]:
                    raise ValueError(f"Invalid response format key: {key}")
                if key == "type":
                    if value not in ["text", "json_object"]:
                        raise ValueError(f"Invalid response format value: {value}")
                elif key == "json_schema":
                    if not isinstance(value, dict):
                        raise ValueError(f"Invalid response format value: {value}")
        else:
            raise ValueError(f"Invalid response format: {self.response_format}")

    def model_post_init(self, __context: Any) -> None:
        self.check_response_format()
        self.client = OpenAI(api_key=self.api_key, base_url=self.endpoint)
        self.aclient = AsyncOpenAI(api_key=self.api_key, base_url=self.endpoint)

    def _call(self, records: List[Message], **kwargs) -> Dict:
        if self.api_key is None or self.api_key == "":
            raise ValueError("api_key is required")
        messages = self._msg2req(records)
        if self.vision:
            res = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=kwargs.get("temperature", self.temperature),
                max_tokens=kwargs.get("max_tokens", self.max_tokens),
                stream=kwargs.get("stream", self.stream),
                n=kwargs.get("n", self.n),
                top_p=kwargs.get("top_p", self.top_p),
                frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty),
                logit_bias=kwargs.get("logit_bias", self.logit_bias),
                logprobs=kwargs.get("logprobs", self.logprobs),
                top_logprobs=kwargs.get("top_logprobs", self.top_logprobs),
                stop=kwargs.get("stop", self.stop),
                stream_options=kwargs.get("stream_options", self.stream_options),
            )
        else:
            res = self.client.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=kwargs.get("temperature", self.temperature),
                max_tokens=kwargs.get("max_tokens", self.max_tokens),
                response_format=kwargs.get("response_format", self.response_format),
                tools=kwargs.get("tools", self.tools),
                tool_choice=kwargs.get("tool_choice", self.tool_choice),
                stream=kwargs.get("stream", self.stream),
                n=kwargs.get("n", self.n),
                top_p=kwargs.get("top_p", self.top_p),
                frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty),
                logit_bias=kwargs.get("logit_bias", self.logit_bias),
                logprobs=kwargs.get("logprobs", self.logprobs),
                top_logprobs=kwargs.get("top_logprobs", self.top_logprobs),
                stop=kwargs.get("stop", self.stop),
                stream_options=kwargs.get("stream_options", self.stream_options),
            )
        # When streaming, return the raw stream object; callers iterate the chunks.
        if kwargs.get("stream", self.stream):
            return res
        return res.model_dump()
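
    # Per-call keyword overrides take precedence over the configured fields,
    # e.g. (records being a prepared List[Message]):
    #   llm._call(records, temperature=0.2, stream=True)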

    async def _acall(self, records: List[Message], **kwargs) -> Dict:
        if self.api_key is None or self.api_key == "":
            raise ValueError("api_key is required")
        messages = self._msg2req(records)
        if self.vision:
            res = await self.aclient.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=kwargs.get("temperature", self.temperature),
                max_tokens=kwargs.get("max_tokens", self.max_tokens),
                n=kwargs.get("n", self.n),
                top_p=kwargs.get("top_p", self.top_p),
                frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty),
                logit_bias=kwargs.get("logit_bias", self.logit_bias),
                logprobs=kwargs.get("logprobs", self.logprobs),
                top_logprobs=kwargs.get("top_logprobs", self.top_logprobs),
                stop=kwargs.get("stop", self.stop),
                stream_options=kwargs.get("stream_options", self.stream_options),
            )
        else:
            res = await self.aclient.chat.completions.create(
                model=self.model_id,
                messages=messages,
                temperature=kwargs.get("temperature", self.temperature),
                max_tokens=kwargs.get("max_tokens", self.max_tokens),
                response_format=kwargs.get("response_format", self.response_format),
                tools=kwargs.get("tools", self.tools),
                tool_choice=kwargs.get("tool_choice", self.tool_choice),
                n=kwargs.get("n", self.n),
                top_p=kwargs.get("top_p", self.top_p),
                frequency_penalty=kwargs.get("frequency_penalty", self.frequency_penalty),
                logit_bias=kwargs.get("logit_bias", self.logit_bias),
                logprobs=kwargs.get("logprobs", self.logprobs),
                top_logprobs=kwargs.get("top_logprobs", self.top_logprobs),
                stop=kwargs.get("stop", self.stop),
                stream_options=kwargs.get("stream_options", self.stream_options),
            )
        return res.model_dump()

    def _msg2req(self, records: List[Message]) -> List[dict]:
        def get_content(msg: List[Content] | Content) -> List[dict] | str:
            if isinstance(msg, list):
                return [c.model_dump(exclude_none=True) for c in msg]
            elif isinstance(msg, Content) and msg.type == "text":
                return msg.text
            elif isinstance(msg, Content) and msg.type == "image_url":
                return [msg.model_dump(exclude_none=True)]
            else:
                raise ValueError(f"Invalid message content: {msg}")

        messages = [
            {"role": message.role, "content": get_content(message.content)}
            for message in records
        ]
        if self.vision:
            # Normalize plain-text user content to the multimodal part format.
            for message in messages:
                if message["role"] == "user" and isinstance(message["content"], str):
                    message["content"] = [{"type": "text", "text": message["content"]}]
            # Merge all user messages into one entry so text and image parts
            # travel together in a single multimodal message.
            processed_messages = []
            merged_dict = {}
            for message in messages:
                if message["role"] == "user":
                    merged_dict["role"] = message["role"]
                    if "content" in merged_dict:
                        merged_dict["content"] += message["content"]
                    else:
                        merged_dict["content"] = message["content"]
                else:
                    processed_messages.append(message)
            if merged_dict:
                processed_messages.append(merged_dict)
            messages = processed_messages
        if self.use_default_sys_prompt:
            messages = [self._generate_default_sys_prompt()] + messages
        return messages
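
    # Illustrative shape of the request list built above for a vision call
    # (all values hypothetical):
    #   [
    #     {"role": "system", "content": "...default system prompt..."},
    #     {"role": "user", "content": [
    #         {"type": "text", "text": "Describe this image."},
    #         {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}},
    #     ]},
    #   ]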

    def _generate_default_sys_prompt(self) -> Dict:
        loc = self._get_location()
        os_name = self._get_linux_distribution()
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        prompt_str = BASIC_SYS_PROMPT.format(current_time, loc, os_name)
        return {"role": "system", "content": prompt_str}

    def _get_linux_distribution(self) -> str:
        platform = sysconfig.get_platform()
        if "linux" in platform:
            if os.path.exists("/etc/lsb-release"):
                with open("/etc/lsb-release", "r") as f:
                    for line in f:
                        if line.startswith("DISTRIB_DESCRIPTION="):
                            return line.split("=", 1)[1].strip().strip('"')
            elif os.path.exists("/etc/os-release"):
                with open("/etc/os-release", "r") as f:
                    for line in f:
                        if line.startswith("PRETTY_NAME="):
                            return line.split("=", 1)[1].strip().strip('"')
        return platform

    def _get_location(self) -> str:
        g = geocoder.ip("me")
        if g.ok and g.city and g.country:
            return g.city + "," + g.country
        return "unknown"