from typing import Any

# TODO: langchain_community.llms.huggingface_endpoint is deprecated.
# Migrating to langchain_huggingface is blocked by the langchain_core 0.3.0 dependency.
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from tenacity import retry, stop_after_attempt, wait_fixed

from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import HandleInput
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput


class HuggingFaceEndpointsComponent(LCModelComponent):
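    """Langflow component that builds a LangChain HuggingFaceEndpoint model for text generation."""
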
display_name: str = "HuggingFace"
description: str = "Generate text using Hugging Face Inference APIs."
icon = "HuggingFace"
name = "HuggingFaceModel"
inputs = [
*LCModelComponent._base_inputs,
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"),
IntInput(
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
),
IntInput(
name="top_k",
display_name="Top K",
advanced=True,
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
),
FloatInput(
name="top_p",
display_name="Top P",
value=0.95,
advanced=True,
info=(
"If set to < 1, only the smallest set of most probable tokens with "
"probabilities that add up to `top_p` or higher are kept for generation"
),
),
FloatInput(
name="typical_p",
display_name="Typical P",
value=0.95,
advanced=True,
info="Typical Decoding mass.",
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.8,
advanced=True,
info="The value used to module the logits distribution",
),
FloatInput(
name="repetition_penalty",
display_name="Repetition Penalty",
info="The parameter for repetition penalty. 1.0 means no penalty.",
advanced=True,
),
StrInput(
name="inference_endpoint",
display_name="Inference Endpoint",
value="https://api-inference.huggingface.co/models/",
info="Custom inference endpoint URL.",
),
DropdownInput(
name="task",
display_name="Task",
options=["text2text-generation", "text-generation", "summarization", "translation"],
advanced=True,
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.",
),
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True),
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True),
HandleInput(
name="output_parser",
display_name="Output Parser",
info="The parser to use to parse the output of the model",
advanced=True,
input_types=["OutputParser"],
),
]

    def get_api_url(self) -> str:
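        """Append the model ID for hosted Hugging Face endpoints; return custom endpoints as-is."""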
if "huggingface" in self.inference_endpoint.lower():
return f"{self.inference_endpoint}{self.model_id}"
return self.inference_endpoint

    def create_huggingface_endpoint(
self,
task: str | None,
huggingfacehub_api_token: str | None,
model_kwargs: dict[str, Any],
max_new_tokens: int,
top_k: int | None,
top_p: float,
typical_p: float | None,
temperature: float | None,
repetition_penalty: float | None,
) -> HuggingFaceEndpoint:
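        """Construct the HuggingFaceEndpoint client, retrying transient failures."""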
retry_attempts = self.retry_attempts
endpoint_url = self.get_api_url()
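
        # Wrap construction in a tenacity retry: stop after `retry_attempts` attempts,
        # waiting 2 seconds between tries.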
@retry(stop=stop_after_attempt(retry_attempts), wait=wait_fixed(2))
def _attempt_create():
return HuggingFaceEndpoint(
endpoint_url=endpoint_url,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
return _attempt_create()

    def build_model(self) -> LanguageModel:
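        """Collect the component inputs and build the HuggingFaceEndpoint language model."""
        # Falsy UI values (0, empty string) are treated as unset so the endpoint falls back to its defaults.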
task = self.task or None
huggingfacehub_api_token = self.huggingfacehub_api_token
model_kwargs = self.model_kwargs or {}
max_new_tokens = self.max_new_tokens
top_k = self.top_k or None
top_p = self.top_p
typical_p = self.typical_p or None
temperature = self.temperature or 0.8
repetition_penalty = self.repetition_penalty or None
try:
llm = self.create_huggingface_endpoint(
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
except Exception as e:
msg = "Could not connect to HuggingFace Endpoints API."
raise ValueError(msg) from e
return llm