"""Langflow component that produces embeddings via the HuggingFace Inference API / TEI."""
from urllib.parse import urlparse

import requests
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from pydantic.v1.types import SecretStr
from tenacity import retry, stop_after_attempt, wait_fixed

from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.io import MessageTextInput, Output, SecretStrInput
class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
    """Langflow component that builds a HuggingFace Inference API / TEI embeddings object.

    Supports both the hosted HuggingFace inference endpoint (API key required) and a
    locally running Text Embeddings Inference (TEI) server (no API key needed; the
    endpoint's /health route is probed instead).
    """

    display_name = "HuggingFace Embeddings Inference"
    description = "Generate embeddings using HuggingFace Text Embeddings Inference (TEI)"
    documentation = "https://huggingface.co/docs/text-embeddings-inference/index"
    icon = "HuggingFace"
    name = "HuggingFaceInferenceAPIEmbeddings"

    inputs = [
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            advanced=True,
            info="Required for non-local inference endpoints. Local inference does not require an API Key.",
        ),
        MessageTextInput(
            name="inference_endpoint",
            display_name="Inference Endpoint",
            required=True,
            value="https://api-inference.huggingface.co/models/",
            info="Custom inference endpoint URL.",
        ),
        MessageTextInput(
            name="model_name",
            display_name="Model Name",
            value="BAAI/bge-large-en-v1.5",
            info="The name of the model to use for text embeddings.",
        ),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
        """Validate that *inference_endpoint* is a well-formed URL and is reachable.

        Args:
            inference_endpoint: Base URL of a (local) TEI server.

        Returns:
            True when the URL parses and its ``/health`` route answers 200 OK.

        Raises:
            ValueError: If the URL is malformed, the server does not respond
                within 5 seconds, or the health check returns a non-OK status.
        """
        parsed_url = urlparse(inference_endpoint)
        if not all([parsed_url.scheme, parsed_url.netloc]):
            # BUG FIX: report the argument actually being validated, not the
            # (possibly different) instance attribute self.inference_endpoint.
            msg = (
                f"Invalid inference endpoint format: '{inference_endpoint}'. "
                "Please ensure the URL includes both a scheme (e.g., 'http://' or 'https://') and a domain name. "
                "Example: 'http://localhost:8080' or 'https://api.example.com'"
            )
            raise ValueError(msg)
        try:
            response = requests.get(f"{inference_endpoint}/health", timeout=5)
        except requests.RequestException as e:
            msg = (
                f"Inference endpoint '{inference_endpoint}' is not responding. "
                "Please ensure the URL is correct and the service is running."
            )
            raise ValueError(msg) from e
        if response.status_code != requests.codes.ok:
            msg = f"HuggingFace health check failed: {response.status_code}"
            raise ValueError(msg)
        # Explicit True so callers (and linters) see a definite success value.
        return True

    def get_api_url(self) -> str:
        """Return the full API URL, appending the model name for hosted HF endpoints."""
        if "huggingface" in self.inference_endpoint.lower():
            # Hosted endpoint: model is addressed as part of the URL path.
            return f"{self.inference_endpoint}{self.model_name}"
        return self.inference_endpoint

    def create_huggingface_embeddings(
        self, api_key: SecretStr, api_url: str, model_name: str
    ) -> HuggingFaceInferenceAPIEmbeddings:
        """Instantiate the LangChain embeddings wrapper for the given endpoint/model."""
        return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=api_url, model_name=model_name)

    def build_embeddings(self) -> Embeddings:
        """Build the Embeddings object, enforcing the API-key policy.

        Local endpoints are health-checked and given a placeholder key; non-local
        endpoints require a real API key.

        Raises:
            ValueError: If a non-local endpoint has no API key, the local health
                check fails, or the embeddings object cannot be constructed.
        """
        api_url = self.get_api_url()
        # Generalized: also treat 0.0.0.0 (common Docker bind address) as local.
        is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1", "http://0.0.0.0"))

        if not self.api_key and is_local_url:
            self.validate_inference_endpoint(api_url)
            api_key = SecretStr("DummyAPIKeyForLocalDeployment")
        elif not self.api_key:
            msg = "API Key is required for non-local inference endpoints"
            raise ValueError(msg)
        else:
            # BUG FIX: pass a SecretStr (matching create_huggingface_embeddings'
            # signature) instead of unwrapping to a plain str with
            # .get_secret_value(), which leaked the secret as a bare string.
            api_key = SecretStr(self.api_key)

        try:
            return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
        except Exception as e:
            msg = "Could not connect to HuggingFace Inference API."
            raise ValueError(msg) from e