Tai Truong
fix readme
d202ada
from urllib.parse import urlparse
import requests
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from pydantic.v1.types import SecretStr
from tenacity import retry, stop_after_attempt, wait_fixed
from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.io import MessageTextInput, Output, SecretStrInput
class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
    """Langflow component that builds a HuggingFace Inference API embeddings client.

    The component exposes three inputs (API key, inference endpoint, model
    name) and a single ``embeddings`` output produced by
    :meth:`build_embeddings`. Local endpoints (``http://localhost`` or
    ``http://127.0.0.1``) may be used without an API key; every other endpoint
    requires one.
    """

    display_name = "HuggingFace Embeddings Inference"
    description = "Generate embeddings using HuggingFace Text Embeddings Inference (TEI)"
    documentation = "https://huggingface.co/docs/text-embeddings-inference/index"
    icon = "HuggingFace"
    name = "HuggingFaceInferenceAPIEmbeddings"

    inputs = [
        SecretStrInput(
            name="api_key",
            display_name="API Key",
            advanced=True,
            info="Required for non-local inference endpoints. Local inference does not require an API Key.",
        ),
        MessageTextInput(
            name="inference_endpoint",
            display_name="Inference Endpoint",
            required=True,
            value="https://api-inference.huggingface.co/models/",
            info="Custom inference endpoint URL.",
        ),
        MessageTextInput(
            name="model_name",
            display_name="Model Name",
            value="BAAI/bge-large-en-v1.5",
            info="The name of the model to use for text embeddings.",
        ),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
        """Check that an endpoint URL is well-formed and its health route responds.

        Args:
            inference_endpoint: Base URL of the inference service to probe.

        Returns:
            ``True`` when the URL has a scheme and host and ``GET <url>/health``
            answers with HTTP 200.

        Raises:
            ValueError: If the URL lacks a scheme or host, if the request
                fails, or if the health route returns a non-200 status.
        """
        parsed_url = urlparse(inference_endpoint)
        if not all([parsed_url.scheme, parsed_url.netloc]):
            # Report the URL that was actually validated (the argument), which
            # may differ from the raw ``self.inference_endpoint`` input.
            msg = (
                f"Invalid inference endpoint format: '{inference_endpoint}'. "
                "Please ensure the URL includes both a scheme (e.g., 'http://' or 'https://') and a domain name. "
                "Example: 'http://localhost:8080' or 'https://api.example.com'"
            )
            raise ValueError(msg)

        # Strip trailing slashes so "http://host:8080/" probes "/health"
        # rather than "//health".
        base_url = inference_endpoint.rstrip("/")
        health_url = f"{base_url}/health"

        try:
            response = requests.get(health_url, timeout=5)
        except requests.RequestException as e:
            msg = (
                f"Inference endpoint '{inference_endpoint}' is not responding. "
                "Please ensure the URL is correct and the service is running."
            )
            raise ValueError(msg) from e

        if response.status_code != requests.codes.ok:
            msg = f"HuggingFace health check failed: {response.status_code}"
            raise ValueError(msg)

        # Explicit return keeps the declared ``bool`` return type consistent
        # (failures are signalled by raising, never by returning False).
        return True

    def get_api_url(self) -> str:
        """Return the URL the embeddings client should call.

        When the configured endpoint contains ``"huggingface"`` (case-insensitive)
        the model name is appended to it; any other endpoint is returned
        unchanged.

        Returns:
            The resolved API URL.
        """
        endpoint = self.inference_endpoint
        if "huggingface" not in endpoint.lower():
            return endpoint

        model_name = self.model_name
        if not model_name:
            # Nothing to append; avoid producing "...None" or a dangling slash.
            return endpoint

        # Join with exactly one slash regardless of how the endpoint was typed
        # ("…/models" and "…/models/" both resolve to "…/models/<model>").
        return f"{endpoint.rstrip('/')}/{model_name.lstrip('/')}"

    # Retries up to 3 attempts with a fixed 2-second wait between attempts.
    @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
    def create_huggingface_embeddings(
        self, api_key: SecretStr, api_url: str, model_name: str
    ) -> HuggingFaceInferenceAPIEmbeddings:
        """Instantiate the HuggingFace Inference API embeddings client.

        Args:
            api_key: Credential forwarded to the client.
            api_url: Fully resolved endpoint URL.
            model_name: Name of the embedding model.

        Returns:
            A configured ``HuggingFaceInferenceAPIEmbeddings`` instance.
        """
        return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=api_url, model_name=model_name)

    def build_embeddings(self) -> Embeddings:
        """Build the embeddings object for the component's ``embeddings`` output.

        Returns:
            The embeddings client created by :meth:`create_huggingface_embeddings`.

        Raises:
            ValueError: If no API key is set for a non-local endpoint, if the
                local endpoint fails validation, or if the client cannot be
                created.
        """
        api_url = self.get_api_url()

        # Prefix match on plain-HTTP loopback URLs only.
        # NOTE(review): this also matches hosts such as
        # "http://localhost.example.com" — confirm whether a hostname
        # comparison is wanted instead.
        is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1"))

        if not self.api_key and is_local_url:
            # Local deployments need no credential, so probe the service and
            # hand the client a placeholder key.
            self.validate_inference_endpoint(api_url)
            api_key = SecretStr("DummyAPIKeyForLocalDeployment")
        elif not self.api_key:
            msg = "API Key is required for non-local inference endpoints"
            raise ValueError(msg)
        else:
            # NOTE(review): this branch passes a plain ``str`` while the local
            # branch passes a ``SecretStr`` — confirm the client accepts both.
            api_key = SecretStr(self.api_key).get_secret_value()

        try:
            return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
        except Exception as e:
            # Component boundary: surface any construction failure as a single
            # user-facing error while preserving the original cause.
            msg = "Could not connect to HuggingFace Inference API."
            raise ValueError(msg) from e