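"""Langflow component that connects to Google's generative AI embeddings service via langchain-google-genai."""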
# from langflow.field_typing import Data
import numpy as np

# TODO: remove ignore once the google package is published with types
from google.ai.generativelanguage_v1beta.types import BatchEmbedContentsRequest
from langchain_core.embeddings import Embeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai._common import GoogleGenerativeAIError

from langflow.custom import Component
from langflow.io import MessageTextInput, Output, SecretStrInput

class GoogleGenerativeAIEmbeddingsComponent(Component):
    display_name = "Google Generative AI Embeddings"
    description = (
        "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, "
        "found in the langchain-google-genai package."
    )
    documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/"
    icon = "Google"
    name = "Google Generative AI Embeddings"

    inputs = [
        SecretStrInput(name="api_key", display_name="API Key"),
        MessageTextInput(name="model_name", display_name="Model Name", value="models/text-embedding-004"),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        if not self.api_key:
            msg = "API Key is required"
            raise ValueError(msg)

        class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
            def __init__(self, *args, **kwargs) -> None:
                # Call the grandparent initializer, skipping GoogleGenerativeAIEmbeddings.__init__.
                super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)

            def embed_documents(
                self,
                texts: list[str],
                *,
                batch_size: int = 100,
                task_type: str | None = None,
                titles: list[str] | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[list[float]]:
                """Embed a list of strings.

                Google Generative AI currently sets a max batch size of 100 strings.

                Args:
                    texts: List[str] The list of strings to embed.
                    batch_size: [int] The batch size of embeddings to send to the model.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    titles: An optional list of titles for texts provided.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Optional reduced dimension for the output embedding.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    List of embeddings, one for each text.
                """
                embeddings: list[list[float]] = []
                batch_start_index = 0
                for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
                    if titles:
                        titles_batch = titles[batch_start_index : batch_start_index + len(batch)]
                        batch_start_index += len(batch)
                    else:
                        titles_batch = [None] * len(batch)  # type: ignore[list-item]

                    requests = [
                        self._prepare_request(
                            text=text,
                            task_type=task_type,
                            title=title,
                            output_dimensionality=output_dimensionality,
                        )
                        for text, title in zip(batch, titles_batch, strict=True)
                    ]

                    try:
                        result = self.client.batch_embed_contents(
                            BatchEmbedContentsRequest(requests=requests, model=self.model)
                        )
                    except Exception as e:
                        msg = f"Error embedding content: {e}"
                        raise GoogleGenerativeAIError(msg) from e

                    # Pad each returned embedding with 768 trailing zeros before collecting it.
                    embeddings.extend([list(np.pad(e.values, (0, 768), "constant")) for e in result.embeddings])
                return embeddings

            def embed_query(
                self,
                text: str,
                task_type: str | None = None,
                title: str | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[float]:
                """Embed a single piece of text.

                Args:
                    text: The text to embed.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    title: An optional title for the text.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Optional reduced dimension for the output embedding.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    Embedding for the text.
                """
                task_type = task_type or "RETRIEVAL_QUERY"
                return self.embed_documents(
                    [text],
                    task_type=task_type,
                    titles=[title] if title else None,
                    output_dimensionality=output_dimensionality,
                )[0]

        return HotaGoogleGenerativeAIEmbeddings(model=self.model_name, google_api_key=self.api_key)
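

# Usage sketch (not part of the Langflow component): a minimal, hedged example of calling
# the same underlying GoogleGenerativeAIEmbeddings class outside Langflow. The GOOGLE_API_KEY
# environment variable and the sample query below are illustrative assumptions; only the
# constructor arguments already used above (model, google_api_key) and the standard
# embed_query method of the Embeddings interface are relied upon.
if __name__ == "__main__":
    import os

    _embedder = GoogleGenerativeAIEmbeddings(
        model="models/text-embedding-004",
        google_api_key=os.environ["GOOGLE_API_KEY"],  # assumed to be set in the environment
    )
    _vector = _embedder.embed_query("What does this component embed?")
    print(f"Embedding length: {len(_vector)}")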