Tai Truong
fix readme
d202ada
from typing import Any
from urllib.parse import urljoin
import httpx
from langchain_openai import ChatOpenAI
from pydantic.v1 import SecretStr
from typing_extensions import override
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.field_typing.range_spec import RangeSpec
from langflow.inputs import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput
from langflow.inputs.inputs import HandleInput
class LMStudioModelComponent(LCModelComponent):
    """Language-model component that talks to an LM Studio server.

    The server is treated as an OpenAI-compatible endpoint: models are listed
    from its ``/models`` route and chat completions go through ``ChatOpenAI``
    pointed at the configured base URL.
    """

    display_name = "LM Studio"
    description = "Generate text using LM Studio Local LLMs."
    icon = "LMStudio"
    name = "LMStudioModel"

    # Fallback API base URL used whenever the "base_url" input is empty.
    DEFAULT_BASE_URL = "http://localhost:1234/v1"

    @override
    def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None):
        """Refresh the model dropdown when the ``model_name`` field is updated.

        Args:
            build_config: The component build configuration; the
                ``model_name`` options are replaced in place.
            field_value: New value of the field being updated (not used here).
            field_name: Name of the field that triggered the update.

        Returns:
            The same ``build_config`` object.

        Raises:
            ValueError: If the model list cannot be retrieved from the server.
        """
        if field_name == "model_name":
            base_url_dict = build_config.get("base_url", {})
            base_url_value = base_url_dict.get("value")
            if base_url_dict.get("load_from_db", False):
                # The field value is resolved through self.variables
                # (presumably a stored-variable lookup) rather than used verbatim.
                base_url_value = self.variables(base_url_value)
            elif not base_url_value:
                base_url_value = self.DEFAULT_BASE_URL
            build_config["model_name"]["options"] = self.get_model(base_url_value)
        return build_config

    @staticmethod
    def _models_url(base_url: str) -> str:
        """Build the model-listing URL for an LM Studio API base URL.

        ``urljoin(base, "/v1/models")`` would replace the whole path of ``base``
        and drop any path prefix (``https://host/proxy/v1`` ->
        ``https://host/v1/models``). Instead the existing path is kept: a base
        ending in ``/v1`` gets ``/models`` appended, and anything else (e.g. a
        bare ``http://host:port``) gets ``/v1/models`` appended.
        """
        # A trailing slash makes urljoin resolve the relative "models" *under*
        # the base path instead of replacing its last segment.
        base = base_url.rstrip("/") + "/"
        if not base.endswith("/v1/"):
            base += "v1/"
        return urljoin(base, "models")

    def get_model(self, base_url_value: str) -> list[str]:
        """Return the IDs of the models served by the LM Studio server.

        Args:
            base_url_value: The LM Studio API base URL, with or without the
                trailing ``/v1`` segment.

        Returns:
            The ``id`` of every entry in the ``data`` array of the response.

        Raises:
            ValueError: If the request fails or the response has an
                unexpected shape.
        """
        try:
            url = self._models_url(base_url_value)
            with httpx.Client() as client:
                response = client.get(url)
                response.raise_for_status()
                data = response.json()
            return [model["id"] for model in data.get("data", [])]
        except Exception as e:
            msg = "Could not retrieve models. Please, make sure the LM Studio server is running."
            raise ValueError(msg) from e

    inputs = [
        *LCModelComponent._base_inputs,
        IntInput(
            name="max_tokens",
            display_name="Max Tokens",
            advanced=True,
            info="The maximum number of tokens to generate. Set to 0 for unlimited tokens.",
            range_spec=RangeSpec(min=0, max=128000),
        ),
        DictInput(name="model_kwargs", display_name="Model Kwargs", advanced=True),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            advanced=False,
            refresh_button=True,
        ),
        StrInput(
            name="base_url",
            display_name="Base URL",
            advanced=False,
            info=f"Endpoint of the LM Studio API. Defaults to '{DEFAULT_BASE_URL}' if not specified.",
            value=DEFAULT_BASE_URL,
        ),
        SecretStrInput(
            name="api_key",
            display_name="LM Studio API Key",
            info="The LM Studio API Key to use for LM Studio.",
            advanced=True,
            value="LMSTUDIO_API_KEY",
        ),
        FloatInput(name="temperature", display_name="Temperature", value=0.1),
        IntInput(
            name="seed",
            display_name="Seed",
            info="The seed controls the reproducibility of the job.",
            advanced=True,
            value=1,
        ),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
    ]

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Build a ``ChatOpenAI`` client pointed at the LM Studio server.

        Returns:
            A ``ChatOpenAI`` instance configured from the component inputs.
        """
        lmstudio_api_key = self.api_key
        temperature = self.temperature
        model_name: str = self.model_name
        max_tokens = self.max_tokens
        model_kwargs = self.model_kwargs or {}
        base_url = self.base_url or self.DEFAULT_BASE_URL
        seed = self.seed

        # NOTE(review): with an empty key, api_key=None is passed and key
        # resolution is left to ChatOpenAI / the openai client (likely an
        # OPENAI_API_KEY environment lookup) -- confirm this is intended for a
        # local server that does not need a key.
        api_key = SecretStr(lmstudio_api_key) if lmstudio_api_key else None
        return ChatOpenAI(
            # max_tokens == 0 is documented as "unlimited"; map it to None.
            max_tokens=max_tokens or None,
            model_kwargs=model_kwargs,
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            temperature=temperature if temperature is not None else 0.1,
            seed=seed,
        )

    def _get_exception_message(self, e: Exception) -> str | None:
        """Get a message from an LM Studio exception.

        Args:
            e: The exception to get the message from.

        Returns:
            The ``message`` entry of an ``openai.BadRequestError`` body when the
            body is a dict containing a non-empty one; otherwise ``None``
            (including when the ``openai`` package is not installed).
        """
        try:
            from openai import BadRequestError
        except ImportError:
            return None
        if isinstance(e, BadRequestError):
            body = e.body
            # openai types ``APIError.body`` as ``object | None``; it may be
            # None or a raw string, so only dict bodies are inspected.
            if isinstance(body, dict):
                message = body.get("message")
                if message:
                    return message
        return None