import dataclasses
import os
import uuid
from typing import Any
import requests
from astrapy.admin import parse_api_endpoint
from langflow.api.v1.schemas import InputValueRequest
from langflow.custom import Component
from langflow.custom.eval import eval_custom_component_code
from langflow.field_typing import Embeddings
from langflow.graph import Graph
from langflow.processing.process import run_graph_internal
def check_env_vars(*env_vars):
"""Check if all specified environment variables are set.
Args:
*env_vars (str): The environment variables to check.
Returns:
bool: True if all environment variables are set, False otherwise.
"""
return all(os.getenv(var) for var in env_vars)
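# Illustrative usage sketch (kept as a comment so nothing runs at import time); the
# environment variable names below are assumptions for demonstration:
#
#     if not check_env_vars("ASTRA_DB_API_ENDPOINT", "ASTRA_DB_APPLICATION_TOKEN"):
#         pytest.skip("Astra DB credentials are not configured")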
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
"""Check if the specified region is valid.
Args:
api_endpoint: The API endpoint to check.
Returns:
True if the region contains hosted nvidia models, False otherwise.
"""
parsed_endpoint = parse_api_endpoint(api_endpoint)
if not parsed_endpoint:
msg = "Invalid ASTRA_DB_API_ENDPOINT"
raise ValueError(msg)
return parsed_endpoint.region == "us-east-2"
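# Illustrative usage sketch (comment only): guard a test that depends on Astra-hosted
# NVIDIA vectorize models. The environment variable name is an assumption:
#
#     api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")
#     if not api_endpoint or not valid_nvidia_vectorize_region(api_endpoint):
#         pytest.skip("NVIDIA vectorize models are not hosted in this Astra DB region")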
class MockEmbeddings(Embeddings):
    """Deterministic, in-memory embeddings that record the inputs they receive."""

    def __init__(self):
        self.embedded_documents = None
self.embedded_query = None
@staticmethod
def mock_embedding(text: str):
return [len(text) / 2, len(text) / 5, len(text) / 10]
def embed_documents(self, texts: list[str]) -> list[list[float]]:
self.embedded_documents = texts
return [self.mock_embedding(text) for text in texts]
def embed_query(self, text: str) -> list[float]:
self.embedded_query = text
return self.mock_embedding(text)
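# Illustrative usage sketch (comment only): the mock produces deterministic,
# length-derived vectors, so tests can assert on them without a real model.
#
#     embeddings = MockEmbeddings()
#     vectors = embeddings.embed_documents(["hello", "world!"])
#     assert vectors[0] == [2.5, 1.0, 0.5]  # len("hello") == 5
#     assert embeddings.embedded_documents == ["hello", "world!"]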
@dataclasses.dataclass
class JSONFlow:
    """Wrapper around a flow's JSON payload with helpers to inspect and mutate its components."""

    json: dict
def get_components_by_type(self, component_type):
result = [node["id"] for node in self.json["data"]["nodes"] if node["data"]["type"] == component_type]
if not result:
msg = (
f"Component of type {component_type} not found, "
f"available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
)
raise ValueError(msg)
return result
def get_component_by_type(self, component_type):
components = self.get_components_by_type(component_type)
if len(components) > 1:
msg = f"Multiple components of type {component_type} found"
raise ValueError(msg)
return components[0]
def set_value(self, component_id, key, value):
done = False
for node in self.json["data"]["nodes"]:
if node["id"] == component_id:
if key not in node["data"]["node"]["template"]:
msg = f"Component {component_id} does not have input {key}"
raise ValueError(msg)
node["data"]["node"]["template"][key]["value"] = value
node["data"]["node"]["template"][key]["load_from_db"] = False
done = True
break
if not done:
msg = f"Component {component_id} not found"
raise ValueError(msg)
def download_flow_from_github(name: str, version: str) -> JSONFlow:
    """Download a starter project flow JSON from the langflow repository at the given release tag."""
response = requests.get(
f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json",
timeout=10,
)
response.raise_for_status()
as_json = response.json()
return JSONFlow(json=as_json)
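# Illustrative usage sketch (comment only): download a starter project, look up a
# component by type, and override one of its template inputs. The flow name, version,
# component type, and field name below are examples, not guaranteed to exist:
#
#     flow = download_flow_from_github("Basic Prompting", "1.0.19")
#     chat_input_id = flow.get_component_by_type("ChatInput")
#     flow.set_value(chat_input_id, "input_value", "hello")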
def download_component_from_github(module: str, file_name: str, version: str) -> Component:
    """Download a component's source from the langflow repository at the given version (or "main")."""
    version_string = f"v{version}" if version != "main" else version
response = requests.get(
f"https://raw.githubusercontent.com/langflow-ai/langflow/{version_string}/src/backend/base/langflow/components/{module}/{file_name}.py",
timeout=10,
)
response.raise_for_status()
return Component(_code=response.text)
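# Illustrative usage sketch (comment only): fetch a component's raw source pinned to a
# release. The module and file name are examples and must match an actual path under
# langflow/components/:
#
#     component = download_component_from_github("embeddings", "openai", "1.0.19")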
async def run_json_flow(
json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None
) -> dict[str, Any]:
    """Build a graph from a flow's JSON payload and run it."""
    graph = Graph.from_payload(json_flow.json)
return await run_flow(graph, run_input, session_id)
async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]:
    """Prepare and run a graph, returning the merged results of all of its outputs."""
    graph.prepare()
graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] if run_input else []
flow_id = str(uuid.uuid4())
results, _ = await run_graph_internal(graph, flow_id, session_id=session_id, inputs=graph_run_inputs)
outputs = {}
for r in results:
for out in r.outputs:
outputs |= out.results
return outputs
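# Illustrative usage sketch (comment only, inside an async test): run a downloaded
# starter flow end to end. The flow name and version are examples:
#
#     flow = download_flow_from_github("Basic Prompting", "1.0.19")
#     outputs = await run_json_flow(flow, run_input="hello")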
@dataclasses.dataclass
class ComponentInputHandle:
    """Describes an edge from another component's output into one input of the component under test."""

    clazz: type
inputs: dict
output_name: str
async def run_single_component(
clazz: type,
inputs: dict | None = None,
run_input: Any | None = None,
session_id: str | None = None,
input_type: str | None = "chat",
) -> dict[str, Any]:
    """Build a one-off graph around `clazz`, run it, and return the component's built object."""
    user_id = str(uuid.uuid4())
    flow_id = str(uuid.uuid4())
    graph = Graph(user_id=user_id, flow_id=flow_id)
    def _add_component(clazz: type, inputs: dict | None = None) -> str:
        # Recursively add the component plus any upstream components referenced via ComponentInputHandle.
        raw_inputs = {}
if inputs:
for key, value in inputs.items():
if not isinstance(value, ComponentInputHandle):
raw_inputs[key] = value
if isinstance(value, Component):
msg = "Component inputs must be wrapped in ComponentInputHandle"
raise TypeError(msg)
component = clazz(**raw_inputs, _user_id=user_id)
component_id = graph.add_component(component)
if inputs:
for input_name, handle in inputs.items():
if isinstance(handle, ComponentInputHandle):
handle_component_id = _add_component(handle.clazz, handle.inputs)
graph.add_component_edge(handle_component_id, (handle.output_name, input_name), component_id)
return component_id
component_id = _add_component(clazz, inputs)
graph.prepare()
graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] if run_input else []
_, _ = await run_graph_internal(
graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id]
)
return graph.get_vertex(component_id).built_object
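# Illustrative usage sketch (comment only, inside an async test): run one component in
# isolation, wiring another component's output into one of its inputs. The component
# classes, input name, and output name below are assumptions for demonstration:
#
#     outputs = await run_single_component(
#         ChatOutput,
#         inputs={
#             "input_value": ComponentInputHandle(clazz=ChatInput, inputs={}, output_name="message"),
#         },
#         run_input="hello",
#     )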
def build_component_instance_for_tests(version: str, module: str, file_name: str, **kwargs):
    """Download a component at `version` and instantiate its custom component class with `kwargs`."""
    component = download_component_from_github(module, file_name, version)
cc_class = eval_custom_component_code(component._code)
return cc_class(**kwargs), component._code
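# Illustrative usage sketch (comment only): pin a component to a released version for a
# regression test. The module path, file name, and keyword arguments are examples:
#
#     instance, code = build_component_instance_for_tests(
#         "1.0.19", module="inputs", file_name="chat", input_value="hello"
#     )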