Spaces:
Running
Running
import dataclasses | |
import os | |
import uuid | |
from typing import Any | |
import requests | |
from astrapy.admin import parse_api_endpoint | |
from langflow.api.v1.schemas import InputValueRequest | |
from langflow.custom import Component | |
from langflow.custom.eval import eval_custom_component_code | |
from langflow.field_typing import Embeddings | |
from langflow.graph import Graph | |
from langflow.processing.process import run_graph_internal | |
def check_env_vars(*env_vars):
    """Report whether every named environment variable has a non-empty value.

    Args:
        *env_vars (str): Names of the environment variables to look up.

    Returns:
        bool: True when ``os.getenv`` returns a truthy value for each name
        (vacuously True when no names are given); False as soon as one
        variable is missing or set to an empty string.
    """
    for name in env_vars:
        # Unset variables yield None and empty ones yield ""; both count as "not set".
        if not os.getenv(name):
            return False
    return True
def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
    """Tell whether an Astra DB API endpoint lies in the region accepted for NVIDIA vectorize.

    Args:
        api_endpoint: The Astra DB API endpoint to inspect.

    Returns:
        True if the region parsed from the endpoint is ``us-east-2``
        (the region treated here as hosting NVIDIA models), False otherwise.

    Raises:
        ValueError: If ``parse_api_endpoint`` returns a falsy result for the endpoint.
    """
    endpoint_info = parse_api_endpoint(api_endpoint)
    if endpoint_info:
        return endpoint_info.region == "us-east-2"

    # The parser could not make sense of the endpoint.
    msg = "Invalid ASTRA_DB_API_ENDPOINT"
    raise ValueError(msg)
class MockEmbeddings(Embeddings):
    """Deterministic stand-in for an embeddings model, for use in tests.

    Vectors depend only on the length of the input text, and the most recent
    inputs are recorded on the instance so tests can inspect what was embedded.
    """

    def __init__(self):
        # Last arguments passed to embed_documents / embed_query; None until first call.
        self.embedded_documents = None
        self.embedded_query = None

    @staticmethod
    def mock_embedding(text: str) -> list[float]:
        """Return a 3-dimensional vector derived from the length of ``text``.

        Declared static because it takes no ``self``; without the decorator
        the ``self.mock_embedding(text)`` calls below would pass the instance
        as ``text`` plus an extra argument and raise ``TypeError``.
        """
        return [len(text) / 2, len(text) / 5, len(text) / 10]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Record ``texts`` and return one mock vector per text."""
        self.embedded_documents = texts
        return [self.mock_embedding(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        """Record ``text`` and return its mock vector."""
        self.embedded_query = text
        return self.mock_embedding(text)
@dataclasses.dataclass
class JSONFlow:
    """Thin wrapper around a flow's JSON payload with helpers to find and edit nodes.

    The payload is expected to hold its nodes under ``json["data"]["nodes"]``,
    each node having an ``"id"``, a ``data.type`` and a ``data.node.template``
    mapping, as accessed by the methods below.

    Declared as a dataclass so it can be built with ``JSONFlow(json=payload)``.
    """

    # Parsed flow payload; edited in place by ``set_value``.
    json: dict

    def get_components_by_type(self, component_type):
        """Return the ids of all nodes whose ``data.type`` equals ``component_type``.

        Raises:
            ValueError: If no node has that type. The message lists the types
                present in the flow, sorted so the text is deterministic.
        """
        nodes = self.json["data"]["nodes"]
        result = [node["id"] for node in nodes if node["data"]["type"] == component_type]
        if not result:
            available = sorted({node["data"]["type"] for node in nodes})
            msg = f"Component of type {component_type} not found, available types: {', '.join(available)}"
            raise ValueError(msg)
        return result

    def get_component_by_type(self, component_type):
        """Return the id of the single node of ``component_type``.

        Raises:
            ValueError: If no node, or more than one node, has that type.
        """
        components = self.get_components_by_type(component_type)
        if len(components) > 1:
            msg = f"Multiple components of type {component_type} found"
            raise ValueError(msg)
        return components[0]

    def set_value(self, component_id, key, value):
        """Set template input ``key`` of node ``component_id`` to ``value``.

        Also sets the input's ``load_from_db`` flag to False.

        Raises:
            ValueError: If the node does not exist or has no input named ``key``.
        """
        for node in self.json["data"]["nodes"]:
            if node["id"] != component_id:
                continue
            template = node["data"]["node"]["template"]
            if key not in template:
                msg = f"Component {component_id} does not have input {key}"
                raise ValueError(msg)
            template[key]["value"] = value
            template[key]["load_from_db"] = False
            return
        msg = f"Component {component_id} not found"
        raise ValueError(msg)
def download_flow_from_github(name: str, version: str) -> JSONFlow:
    """Download a starter-project flow from the langflow GitHub repository.

    Args:
        name: File name (without ``.json``) of the starter project.
        version: Version number, used as the ``v``-prefixed git ref, or
            ``"main"`` to fetch from the main branch.

    Returns:
        A ``JSONFlow`` wrapping the parsed JSON payload.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    # Same ref rule as download_component_from_github: versions get a "v"
    # prefix, while "main" is used verbatim instead of becoming "vmain".
    version_string = f"v{version}" if version != "main" else version
    response = requests.get(
        f"https://raw.githubusercontent.com/langflow-ai/langflow/{version_string}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json",
        timeout=10,
    )
    response.raise_for_status()
    as_json = response.json()
    return JSONFlow(json=as_json)
def download_component_from_github(module: str, file_name: str, version: str) -> Component:
    """Fetch a component's source file from the langflow GitHub repository.

    Args:
        module: Sub-package under ``langflow/components`` holding the file.
        file_name: Name of the Python file, without the ``.py`` suffix.
        version: Version number, used as the ``v``-prefixed git ref, or
            ``"main"`` to fetch from the main branch.

    Returns:
        A ``Component`` whose ``_code`` is the downloaded source text.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    if version == "main":
        version_string = version
    else:
        version_string = f"v{version}"

    url = f"https://raw.githubusercontent.com/langflow-ai/langflow/{version_string}/src/backend/base/langflow/components/{module}/{file_name}.py"
    reply = requests.get(url, timeout=10)
    reply.raise_for_status()
    return Component(_code=reply.text)
async def run_json_flow(
    json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None
) -> dict[str, Any]:
    """Build a graph from a ``JSONFlow`` payload and execute it.

    Args:
        json_flow: Flow whose ``json`` payload is passed to ``Graph.from_payload``.
        run_input: Optional input value forwarded to ``run_flow``.
        session_id: Optional session identifier forwarded to ``run_flow``.

    Returns:
        Whatever ``run_flow`` returns for the constructed graph.
    """
    flow_graph = Graph.from_payload(json_flow.json)
    flow_outputs = await run_flow(flow_graph, run_input, session_id)
    return flow_outputs
async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]:
    """Prepare and run ``graph``, merging every output's results into one dict.

    Args:
        graph: Graph to execute; ``prepare()`` is called on it first.
        run_input: Optional value sent as a single ``"chat"`` input. A falsy
            value results in no inputs being sent.
        session_id: Optional session identifier passed to the runner.

    Returns:
        A dict combining ``results`` of all outputs of all run results; on
        duplicate keys, later outputs overwrite earlier ones.
    """
    graph.prepare()

    graph_run_inputs = []
    if run_input:
        graph_run_inputs.append(InputValueRequest(input_value=run_input, type="chat"))

    # A fresh random flow id is used for each run.
    flow_id = str(uuid.uuid4())
    run_results, _ = await run_graph_internal(graph, flow_id, session_id=session_id, inputs=graph_run_inputs)

    merged = {}
    for run_result in run_results:
        for output in run_result.outputs:
            merged.update(output.results)
    return merged
@dataclasses.dataclass
class ComponentInputHandle:
    """Describes an upstream component whose output feeds another component's input.

    Declared as a dataclass so instances can be constructed with, and carry,
    the three fields that ``run_single_component`` reads from a handle.
    """

    # Component class to instantiate as the upstream node.
    clazz: type
    # Inputs for the upstream component; values may themselves be handles.
    inputs: dict
    # Name of the upstream component's output to connect from.
    output_name: str
async def run_single_component(
    clazz: type,
    inputs: dict | None = None,
    run_input: Any | None = None,
    session_id: str | None = None,
    input_type: str | None = "chat",
) -> dict[str, Any]:
    """Run one component, plus any upstream components given as handles, in a fresh graph.

    Args:
        clazz: Component class to instantiate and run.
        inputs: Keyword inputs for the component. Plain values are passed to
            the constructor; ``ComponentInputHandle`` values become upstream
            components connected by an edge.
        run_input: Optional value sent as the graph's run input. A falsy
            value results in no inputs being sent.
        session_id: Optional session identifier passed to the runner.
        input_type: Type tag attached to the run input.

    Returns:
        The ``built_object`` of the target component's vertex after the run.

    Raises:
        TypeError: If an input value is a bare ``Component`` instance rather
            than a ``ComponentInputHandle``.
    """
    user_id = str(uuid.uuid4())
    flow_id = str(uuid.uuid4())
    graph = Graph(user_id=user_id, flow_id=flow_id)

    def _add_component(clazz: type, inputs: dict | None = None) -> str:
        """Add ``clazz`` and, recursively, its handle inputs to the graph; return its id."""
        supplied = inputs or {}

        # Only non-handle values go to the constructor; bare components are rejected.
        literal_inputs = {}
        for name, value in supplied.items():
            if isinstance(value, Component):
                msg = "Component inputs must be wrapped in ComponentInputHandle"
                raise TypeError(msg)
            if not isinstance(value, ComponentInputHandle):
                literal_inputs[name] = value

        instance = clazz(**literal_inputs, _user_id=user_id)
        node_id = graph.add_component(instance)

        # The component itself is added before its upstream components, each
        # of which is then linked from its named output to the matching input.
        for input_name, handle in supplied.items():
            if not isinstance(handle, ComponentInputHandle):
                continue
            upstream_id = _add_component(handle.clazz, handle.inputs)
            graph.add_component_edge(upstream_id, (handle.output_name, input_name), node_id)
        return node_id

    component_id = _add_component(clazz, inputs)
    graph.prepare()

    graph_run_inputs = []
    if run_input:
        graph_run_inputs.append(InputValueRequest(input_value=run_input, type=input_type))

    # The runner's return values are not used; the result is read off the vertex.
    _results, _session = await run_graph_internal(
        graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id]
    )
    return graph.get_vertex(component_id).built_object
def build_component_instance_for_tests(version: str, module: str, file_name: str, **kwargs):
    """Download a component's source for a given version and instantiate it.

    Args:
        version: Version passed to ``download_component_from_github``.
        module: Components sub-package that holds the file.
        file_name: Name of the component file, without the ``.py`` suffix.
        **kwargs: Keyword arguments forwarded to the component class constructor.

    Returns:
        A ``(instance, code)`` tuple: the constructed component and the
        downloaded source text it was built from.
    """
    downloaded = download_component_from_github(module, file_name, version)
    source_code = downloaded._code
    # NOTE(review): this evaluates code fetched over the network — intended for tests only.
    component_class = eval_custom_component_code(source_code)
    instance = component_class(**kwargs)
    return instance, source_code