# NOTE(review): the original file began with scraped web-UI residue ("Spaces: / Running / Running"),
# which is not valid Python; replaced with this comment so the module parses.
from typing import Any | |
from urllib.parse import urljoin | |
import httpx | |
from langchain_ollama import ChatOllama | |
from langflow.base.models.model import LCModelComponent | |
from langflow.field_typing import LanguageModel | |
from langflow.inputs.inputs import HandleInput | |
from langflow.io import BoolInput, DictInput, DropdownInput, FloatInput, IntInput, StrInput | |
class ChatOllamaComponent(LCModelComponent):
    """Langflow model component that builds a `ChatOllama` LLM from a local Ollama server.

    Exposes Ollama generation parameters (mirostat sampling, context window,
    GPU/thread counts, repeat penalties, etc.) as Langflow inputs and assembles
    them into a `langchain_ollama.ChatOllama` instance in `build_model`.
    """

    display_name = "Ollama"
    description = "Generate text using Ollama Local LLMs."
    icon = "Ollama"
    name = "OllamaModel"

    def update_build_config(self, build_config: dict, field_value: Any, field_name: str | None = None) -> dict:
        """React to UI field changes and adjust dependent fields in `build_config`.

        Args:
            build_config: The component's build configuration (field name -> field spec dict).
            field_value: The new value of the field that changed.
            field_name: The name of the field that changed.

        Returns:
            The (mutated) `build_config`.

        Raises:
            ValueError: Propagated from `get_model` when the Ollama server is unreachable.
        """
        if field_name == "mirostat":
            if field_value == "Disabled":
                # Hide and clear the mirostat tuning knobs when sampling is off.
                build_config["mirostat_eta"]["advanced"] = True
                build_config["mirostat_tau"]["advanced"] = True
                build_config["mirostat_eta"]["value"] = None
                build_config["mirostat_tau"]["value"] = None
            else:
                # Surface the tuning knobs and seed them with per-version defaults.
                build_config["mirostat_eta"]["advanced"] = False
                build_config["mirostat_tau"]["advanced"] = False

                if field_value == "Mirostat 2.0":
                    build_config["mirostat_eta"]["value"] = 0.2
                    build_config["mirostat_tau"]["value"] = 10
                else:
                    build_config["mirostat_eta"]["value"] = 0.1
                    build_config["mirostat_tau"]["value"] = 5

        if field_name == "model_name":
            base_url_dict = build_config.get("base_url", {})
            base_url_load_from_db = base_url_dict.get("load_from_db", False)
            base_url_value = base_url_dict.get("value")
            if base_url_load_from_db:
                # The URL is stored as a Langflow variable; resolve it first.
                base_url_value = self.variables(base_url_value, field_name)
            elif not base_url_value:
                base_url_value = "http://localhost:11434"
            # Refresh the model dropdown from the live server's model list.
            build_config["model_name"]["options"] = self.get_model(base_url_value)

        if field_name == "keep_alive_flag":
            if field_value == "Keep":
                # -1 tells Ollama to keep the model loaded indefinitely.
                build_config["keep_alive"]["value"] = "-1"
                build_config["keep_alive"]["advanced"] = True
            elif field_value == "Immediately":
                # 0 tells Ollama to unload the model right after the request.
                build_config["keep_alive"]["value"] = "0"
                build_config["keep_alive"]["advanced"] = True
            else:
                # Custom duration: expose the keep_alive field for manual entry.
                build_config["keep_alive"]["advanced"] = False

        return build_config

    def get_model(self, base_url_value: str) -> list[str]:
        """Return the names of models available on the Ollama server at `base_url_value`.

        Queries the Ollama REST endpoint `GET /api/tags`.

        Raises:
            ValueError: If the server cannot be reached, returns an HTTP error,
                or responds with an unexpected payload.
        """
        try:
            url = urljoin(base_url_value, "/api/tags")
            with httpx.Client() as client:
                response = client.get(url)
                response.raise_for_status()
                data = response.json()
            return [model["name"] for model in data.get("models", [])]
        # Narrowed from a bare `Exception`: httpx.HTTPError covers connection and
        # HTTP-status failures, ValueError covers JSON decode errors, and KeyError
        # covers a malformed payload missing the "name" field.
        except (httpx.HTTPError, ValueError, KeyError) as e:
            msg = "Could not retrieve models. Please, make sure Ollama is running."
            raise ValueError(msg) from e

    inputs = [
        StrInput(
            name="base_url",
            display_name="Base URL",
            info="Endpoint of the Ollama API. Defaults to 'http://localhost:11434' if not specified.",
            value="http://localhost:11434",
        ),
        DropdownInput(
            name="model_name",
            display_name="Model Name",
            value="llama3.1",
            info="Refer to https://ollama.com/library for more models.",
            refresh_button=True,
        ),
        FloatInput(
            name="temperature",
            display_name="Temperature",
            value=0.2,
            info="Controls the creativity of model responses.",
        ),
        StrInput(
            name="format", display_name="Format", info="Specify the format of the output (e.g., json).", advanced=True
        ),
        DictInput(name="metadata", display_name="Metadata", info="Metadata to add to the run trace.", advanced=True),
        DropdownInput(
            name="mirostat",
            display_name="Mirostat",
            options=["Disabled", "Mirostat", "Mirostat 2.0"],
            info="Enable/disable Mirostat sampling for controlling perplexity.",
            value="Disabled",
            advanced=True,
            real_time_refresh=True,
        ),
        FloatInput(
            name="mirostat_eta",
            display_name="Mirostat Eta",
            info="Learning rate for Mirostat algorithm. (Default: 0.1)",
            advanced=True,
        ),
        FloatInput(
            name="mirostat_tau",
            display_name="Mirostat Tau",
            info="Controls the balance between coherence and diversity of the output. (Default: 5.0)",
            advanced=True,
        ),
        IntInput(
            name="num_ctx",
            display_name="Context Window Size",
            info="Size of the context window for generating tokens. (Default: 2048)",
            advanced=True,
        ),
        IntInput(
            name="num_gpu",
            display_name="Number of GPUs",
            info="Number of GPUs to use for computation. (Default: 1 on macOS, 0 to disable)",
            advanced=True,
        ),
        IntInput(
            name="num_thread",
            display_name="Number of Threads",
            info="Number of threads to use during computation. (Default: detected for optimal performance)",
            advanced=True,
        ),
        IntInput(
            name="repeat_last_n",
            display_name="Repeat Last N",
            info="How far back the model looks to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
            advanced=True,
        ),
        FloatInput(
            name="repeat_penalty",
            display_name="Repeat Penalty",
            info="Penalty for repetitions in generated text. (Default: 1.1)",
            advanced=True,
        ),
        FloatInput(name="tfs_z", display_name="TFS Z", info="Tail free sampling value. (Default: 1)", advanced=True),
        IntInput(name="timeout", display_name="Timeout", info="Timeout for the request stream.", advanced=True),
        IntInput(
            name="top_k", display_name="Top K", info="Limits token selection to top K. (Default: 40)", advanced=True
        ),
        FloatInput(name="top_p", display_name="Top P", info="Works together with top-k. (Default: 0.9)", advanced=True),
        BoolInput(name="verbose", display_name="Verbose", info="Whether to print out response text.", advanced=True),
        StrInput(
            name="tags",
            display_name="Tags",
            info="Comma-separated list of tags to add to the run trace.",
            advanced=True,
        ),
        StrInput(
            name="stop_tokens",
            display_name="Stop Tokens",
            info="Comma-separated list of tokens to signal the model to stop generating text.",
            advanced=True,
        ),
        StrInput(name="system", display_name="System", info="System to use for generating text.", advanced=True),
        StrInput(name="template", display_name="Template", info="Template to use for generating text.", advanced=True),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
        *LCModelComponent._base_inputs,
    ]

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Assemble and return a `ChatOllama` instance from the component's inputs.

        Raises:
            ValueError: If `ChatOllama` cannot be constructed from the given parameters.
        """
        # Map the UI's mirostat choice onto Ollama's numeric setting (0 disables it).
        mirostat_options = {"Mirostat": 1, "Mirostat 2.0": 2}
        mirostat_value = mirostat_options.get(self.mirostat, 0)

        # When mirostat is disabled its tuning knobs must not be sent at all.
        if mirostat_value == 0:
            mirostat_eta = None
            mirostat_tau = None
        else:
            mirostat_eta = self.mirostat_eta
            mirostat_tau = self.mirostat_tau

        # NOTE(review): the `x or None` pattern below treats 0 / 0.0 / "" as "unset",
        # so e.g. an explicit temperature of 0 is dropped and Ollama's own default
        # applies — confirm this is the intended contract before changing it.
        llm_params = {
            "base_url": self.base_url,
            "model": self.model_name,
            "mirostat": mirostat_value,
            "format": self.format,
            "metadata": self.metadata,
            "tags": self.tags.split(",") if self.tags else None,
            "mirostat_eta": mirostat_eta,
            "mirostat_tau": mirostat_tau,
            "num_ctx": self.num_ctx or None,
            "num_gpu": self.num_gpu or None,
            "num_thread": self.num_thread or None,
            "repeat_last_n": self.repeat_last_n or None,
            "repeat_penalty": self.repeat_penalty or None,
            "temperature": self.temperature or None,
            "stop": self.stop_tokens.split(",") if self.stop_tokens else None,
            "system": self.system,
            "template": self.template,
            "tfs_z": self.tfs_z or None,
            "timeout": self.timeout or None,
            "top_k": self.top_k or None,
            "top_p": self.top_p or None,
            "verbose": self.verbose,
        }

        # Strip unset parameters so ChatOllama falls back to its own defaults.
        llm_params = {k: v for k, v in llm_params.items() if v is not None}

        try:
            output = ChatOllama(**llm_params)
        except Exception as e:
            msg = "Could not initialize Ollama LLM."
            raise ValueError(msg) from e

        return output