Tai Truong
fix readme
d202ada
# from langflow.field_typing import Data
import numpy as np
# TODO: remove ignore once the google package is published with types
from google.ai.generativelanguage_v1beta.types import BatchEmbedContentsRequest
from langchain_core.embeddings import Embeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_google_genai._common import GoogleGenerativeAIError
from langflow.custom import Component
from langflow.io import MessageTextInput, Output, SecretStrInput
class GoogleGenerativeAIEmbeddingsComponent(Component):
    """Langflow component that builds a Google Generative AI embeddings object.

    The component takes an API key and a model name as inputs and exposes a
    single output, produced by :meth:`build_embeddings`.
    """

    display_name = "Google Generative AI Embeddings"
    description = (
        "Connect to Google's generative AI embeddings service using the GoogleGenerativeAIEmbeddings class, "
        "found in the langchain-google-genai package."
    )
    documentation: str = "https://python.langchain.com/v0.2/docs/integrations/text_embedding/google_generative_ai/"
    icon = "Google"
    name = "Google Generative AI Embeddings"

    inputs = [
        SecretStrInput(name="api_key", display_name="API Key"),
        MessageTextInput(name="model_name", display_name="Model Name", value="models/text-embedding-004"),
    ]

    outputs = [
        Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
    ]

    def build_embeddings(self) -> Embeddings:
        """Create the embeddings object for the configured model and API key.

        Returns:
            An ``Embeddings`` implementation whose vectors are zero-padded up
            to the requested ``output_dimensionality`` (1536 by default).

        Raises:
            ValueError: If no API key has been provided.
        """
        if not self.api_key:
            msg = "API Key is required"
            raise ValueError(msg)

        class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
            """``GoogleGenerativeAIEmbeddings`` variant that zero-pads short vectors."""

            def __init__(self, *args, **kwargs) -> None:
                # NOTE(review): the two-argument super() starts the MRO lookup
                # *after* GoogleGenerativeAIEmbeddings, so that class's own
                # __init__ (if it defines one) is skipped. Kept as in the
                # original -- confirm this is intentional.
                super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs)

            def embed_documents(
                self,
                texts: list[str],
                *,
                batch_size: int = 100,
                task_type: str | None = None,
                titles: list[str] | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[list[float]]:
                """Embed a list of strings.

                Google Generative AI currently sets a max batch size of 100 strings.

                Args:
                    texts: List[str] The list of strings to embed.
                    batch_size: [int] The batch size of embeddings to send to the model
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType)
                    titles: An optional list of titles for texts provided.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Dimension passed to the embedding request and
                        used as the zero-padding target for the returned vectors.
                        When ``None``, vectors are returned unpadded.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    List of embeddings, one for each text. A vector shorter than
                    ``output_dimensionality`` is right-padded with zeros up to that
                    length; a vector that is already long enough is left unchanged.

                Raises:
                    GoogleGenerativeAIError: If the batch embedding call fails.
                    ValueError: If ``titles`` is given but does not line up with
                        ``texts`` (raised by the strict ``zip``).
                """
                embeddings: list[list[float]] = []
                batch_start_index = 0
                for batch in GoogleGenerativeAIEmbeddings._prepare_batches(texts, batch_size):
                    if titles:
                        # Take the titles that correspond to this batch of texts.
                        titles_batch = titles[batch_start_index : batch_start_index + len(batch)]
                        batch_start_index += len(batch)
                    else:
                        titles_batch = [None] * len(batch)  # type: ignore[list-item]

                    requests = [
                        self._prepare_request(
                            text=text,
                            task_type=task_type,
                            title=title,
                            output_dimensionality=output_dimensionality,
                        )
                        for text, title in zip(batch, titles_batch, strict=True)
                    ]

                    try:
                        result = self.client.batch_embed_contents(
                            BatchEmbedContentsRequest(requests=requests, model=self.model)
                        )
                    except Exception as exc:
                        msg = f"Error embedding content: {exc}"
                        raise GoogleGenerativeAIError(msg) from exc

                    for embedding in result.embeddings:
                        values = np.asarray(embedding.values, dtype=float)
                        # Pad only the shortfall up to the requested dimension.
                        # A fixed pad of 768 zeros is only right when the model
                        # returns exactly 768 values for a 1536 target; any other
                        # returned length would produce a wrong-sized vector.
                        if output_dimensionality is None:
                            missing = 0
                        else:
                            missing = max(0, output_dimensionality - values.size)
                        # .tolist() yields plain Python floats rather than NumPy scalars.
                        embeddings.append(np.pad(values, (0, missing), "constant").tolist())
                return embeddings

            def embed_query(
                self,
                text: str,
                task_type: str | None = None,
                title: str | None = None,
                output_dimensionality: int | None = 1536,
            ) -> list[float]:
                """Embed a text.

                Args:
                    text: The text to embed.
                    task_type: task_type (https://ai.google.dev/api/rest/v1/TaskType).
                        Defaults to ``"RETRIEVAL_QUERY"`` when not given.
                    title: An optional title for the text.
                        Only applicable when TaskType is RETRIEVAL_DOCUMENT.
                    output_dimensionality: Dimension passed to the embedding request and
                        used as the zero-padding target for the returned vector.
                        https://ai.google.dev/api/rest/v1/models/batchEmbedContents#EmbedContentRequest

                Returns:
                    Embedding for the text.
                """
                task_type = task_type or "RETRIEVAL_QUERY"
                # Delegate to the batch path with a single-item batch.
                return self.embed_documents(
                    [text],
                    task_type=task_type,
                    titles=[title] if title else None,
                    output_dimensionality=output_dimensionality,
                )[0]

        return HotaGoogleGenerativeAIEmbeddings(model=self.model_name, google_api_key=self.api_key)