Spaces:
Running
Running
from typing import Any | |
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint | |
from tenacity import retry, stop_after_attempt, wait_fixed | |
# TODO: langchain_community.llms.huggingface_endpoint is depreciated. | |
# Need to update to langchain_huggingface, but have dependency with langchain_core 0.3.0 | |
from langflow.base.models.model import LCModelComponent | |
from langflow.field_typing import LanguageModel | |
from langflow.inputs.inputs import HandleInput | |
from langflow.io import DictInput, DropdownInput, FloatInput, IntInput, SecretStrInput, StrInput | |
class HuggingFaceEndpointsComponent(LCModelComponent): | |
display_name: str = "HuggingFace" | |
description: str = "Generate text using Hugging Face Inference APIs." | |
icon = "HuggingFace" | |
name = "HuggingFaceModel" | |
inputs = [ | |
*LCModelComponent._base_inputs, | |
StrInput(name="model_id", display_name="Model ID", value="openai-community/gpt2"), | |
IntInput( | |
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens" | |
), | |
IntInput( | |
name="top_k", | |
display_name="Top K", | |
advanced=True, | |
info="The number of highest probability vocabulary tokens to keep for top-k-filtering", | |
), | |
FloatInput( | |
name="top_p", | |
display_name="Top P", | |
value=0.95, | |
advanced=True, | |
info=( | |
"If set to < 1, only the smallest set of most probable tokens with " | |
"probabilities that add up to `top_p` or higher are kept for generation" | |
), | |
), | |
FloatInput( | |
name="typical_p", | |
display_name="Typical P", | |
value=0.95, | |
advanced=True, | |
info="Typical Decoding mass.", | |
), | |
FloatInput( | |
name="temperature", | |
display_name="Temperature", | |
value=0.8, | |
advanced=True, | |
info="The value used to module the logits distribution", | |
), | |
FloatInput( | |
name="repetition_penalty", | |
display_name="Repetition Penalty", | |
info="The parameter for repetition penalty. 1.0 means no penalty.", | |
advanced=True, | |
), | |
StrInput( | |
name="inference_endpoint", | |
display_name="Inference Endpoint", | |
value="https://api-inference.huggingface.co/models/", | |
info="Custom inference endpoint URL.", | |
), | |
DropdownInput( | |
name="task", | |
display_name="Task", | |
options=["text2text-generation", "text-generation", "summarization", "translation"], | |
advanced=True, | |
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.", | |
), | |
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True), | |
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True), | |
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True), | |
HandleInput( | |
name="output_parser", | |
display_name="Output Parser", | |
info="The parser to use to parse the output of the model", | |
advanced=True, | |
input_types=["OutputParser"], | |
), | |
] | |
def get_api_url(self) -> str: | |
if "huggingface" in self.inference_endpoint.lower(): | |
return f"{self.inference_endpoint}{self.model_id}" | |
return self.inference_endpoint | |
def create_huggingface_endpoint( | |
self, | |
task: str | None, | |
huggingfacehub_api_token: str | None, | |
model_kwargs: dict[str, Any], | |
max_new_tokens: int, | |
top_k: int | None, | |
top_p: float, | |
typical_p: float | None, | |
temperature: float | None, | |
repetition_penalty: float | None, | |
) -> HuggingFaceEndpoint: | |
retry_attempts = self.retry_attempts | |
endpoint_url = self.get_api_url() | |
def _attempt_create(): | |
return HuggingFaceEndpoint( | |
endpoint_url=endpoint_url, | |
task=task, | |
huggingfacehub_api_token=huggingfacehub_api_token, | |
model_kwargs=model_kwargs, | |
max_new_tokens=max_new_tokens, | |
top_k=top_k, | |
top_p=top_p, | |
typical_p=typical_p, | |
temperature=temperature, | |
repetition_penalty=repetition_penalty, | |
) | |
return _attempt_create() | |
def build_model(self) -> LanguageModel: | |
task = self.task or None | |
huggingfacehub_api_token = self.huggingfacehub_api_token | |
model_kwargs = self.model_kwargs or {} | |
max_new_tokens = self.max_new_tokens | |
top_k = self.top_k or None | |
top_p = self.top_p | |
typical_p = self.typical_p or None | |
temperature = self.temperature or 0.8 | |
repetition_penalty = self.repetition_penalty or None | |
try: | |
llm = self.create_huggingface_endpoint( | |
task=task, | |
huggingfacehub_api_token=huggingfacehub_api_token, | |
model_kwargs=model_kwargs, | |
max_new_tokens=max_new_tokens, | |
top_k=top_k, | |
top_p=top_p, | |
typical_p=typical_p, | |
temperature=temperature, | |
repetition_penalty=repetition_penalty, | |
) | |
except Exception as e: | |
msg = "Could not connect to HuggingFace Endpoints API." | |
raise ValueError(msg) from e | |
return llm | |