import dataclasses
import os
import uuid
from typing import Any

import requests
from astrapy.admin import parse_api_endpoint
from langflow.api.v1.schemas import InputValueRequest
from langflow.custom import Component
from langflow.custom.eval import eval_custom_component_code
from langflow.field_typing import Embeddings
from langflow.graph import Graph
from langflow.processing.process import run_graph_internal


def check_env_vars(*env_vars):
    """Check if all specified environment variables are set.

    Args:
    *env_vars (str): The environment variables to check.

    Returns:
    bool: True if all environment variables are set, False otherwise.
    """
    return all(os.getenv(var) for var in env_vars)


def valid_nvidia_vectorize_region(api_endpoint: str) -> bool:
    """Check if the specified region is valid.

    Args:
        api_endpoint: The API endpoint to check.

    Returns:
        True if the region contains hosted nvidia models, False otherwise.
    """
    parsed_endpoint = parse_api_endpoint(api_endpoint)
    if not parsed_endpoint:
        msg = "Invalid ASTRA_DB_API_ENDPOINT"
        raise ValueError(msg)
    return parsed_endpoint.region == "us-east-2"


class MockEmbeddings(Embeddings):
    def __init__(self):
        self.embedded_documents = None
        self.embedded_query = None

    @staticmethod
    def mock_embedding(text: str):
        return [len(text) / 2, len(text) / 5, len(text) / 10]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        self.embedded_documents = texts
        return [self.mock_embedding(text) for text in texts]

    def embed_query(self, text: str) -> list[float]:
        self.embedded_query = text
        return self.mock_embedding(text)


@dataclasses.dataclass
class JSONFlow:
    json: dict

    def get_components_by_type(self, component_type):
        result = [node["id"] for node in self.json["data"]["nodes"] if node["data"]["type"] == component_type]
        if not result:
            msg = (
                f"Component of type {component_type} not found, "
                f"available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}"
            )
            raise ValueError(msg)
        return result

    def get_component_by_type(self, component_type):
        components = self.get_components_by_type(component_type)
        if len(components) > 1:
            msg = f"Multiple components of type {component_type} found"
            raise ValueError(msg)
        return components[0]

    def set_value(self, component_id, key, value):
        done = False
        for node in self.json["data"]["nodes"]:
            if node["id"] == component_id:
                if key not in node["data"]["node"]["template"]:
                    msg = f"Component {component_id} does not have input {key}"
                    raise ValueError(msg)
                node["data"]["node"]["template"][key]["value"] = value
                node["data"]["node"]["template"][key]["load_from_db"] = False
                done = True
                break
        if not done:
            msg = f"Component {component_id} not found"
            raise ValueError(msg)


def download_flow_from_github(name: str, version: str) -> JSONFlow:
    response = requests.get(
        f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json",
        timeout=10,
    )
    response.raise_for_status()
    as_json = response.json()
    return JSONFlow(json=as_json)


def download_component_from_github(module: str, file_name: str, version: str) -> Component:
    version_string = f"v{version}" if version != "main" else version
    response = requests.get(
        f"https://raw.githubusercontent.com/langflow-ai/langflow/{version_string}/src/backend/base/langflow/components/{module}/{file_name}.py",
        timeout=10,
    )
    response.raise_for_status()
    return Component(_code=response.text)


async def run_json_flow(
    json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None
) -> dict[str, Any]:
    graph = Graph.from_payload(json_flow.json)
    return await run_flow(graph, run_input, session_id)


async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]:
    graph.prepare()
    graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] if run_input else []

    flow_id = str(uuid.uuid4())

    results, _ = await run_graph_internal(graph, flow_id, session_id=session_id, inputs=graph_run_inputs)
    outputs = {}
    for r in results:
        for out in r.outputs:
            outputs |= out.results
    return outputs


@dataclasses.dataclass
class ComponentInputHandle:
    clazz: type
    inputs: dict
    output_name: str


async def run_single_component(
    clazz: type,
    inputs: dict | None = None,
    run_input: Any | None = None,
    session_id: str | None = None,
    input_type: str | None = "chat",
) -> dict[str, Any]:
    user_id = str(uuid.uuid4())
    flow_id = str(uuid.uuid4())
    graph = Graph(user_id=user_id, flow_id=flow_id)

    def _add_component(clazz: type, inputs: dict | None = None) -> str:
        raw_inputs = {}
        if inputs:
            for key, value in inputs.items():
                if not isinstance(value, ComponentInputHandle):
                    raw_inputs[key] = value
                if isinstance(value, Component):
                    msg = "Component inputs must be wrapped in ComponentInputHandle"
                    raise TypeError(msg)
        component = clazz(**raw_inputs, _user_id=user_id)
        component_id = graph.add_component(component)
        if inputs:
            for input_name, handle in inputs.items():
                if isinstance(handle, ComponentInputHandle):
                    handle_component_id = _add_component(handle.clazz, handle.inputs)
                    graph.add_component_edge(handle_component_id, (handle.output_name, input_name), component_id)
        return component_id

    component_id = _add_component(clazz, inputs)
    graph.prepare()
    graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] if run_input else []

    _, _ = await run_graph_internal(
        graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id]
    )
    return graph.get_vertex(component_id).built_object


def build_component_instance_for_tests(version: str, module: str, file_name: str, **kwargs):
    component = download_component_from_github(module, file_name, version)
    cc_class = eval_custom_component_code(component._code)
    return cc_class(**kwargs), component._code