Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
API Client Framework for api.comfy.org.

This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous (polled) API operations with Pydantic type validation.

Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication, error handling and retries
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. SynchronousOperation - Executes a single API request and returns the parsed response
4. PollingOperation - Polls an endpoint until a task reaches a completed or failed status

Usage Examples:
--------------

# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:

# 1. Create the API client
api_client = ApiClient(
    base_url="https://api.example.com",
    auth_token="your_auth_token_here",
    comfy_api_key="your_comfy_api_key_here",
    timeout=30.0,
    verify_ssl=True
)

# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
    path="/v1/users/me",
    method=HttpMethod.GET,
    request_model=EmptyRequest,  # No request body needed
    response_model=UserProfile,  # Pydantic model for the response
    query_params=None
)

# 3. Create the request object
request = EmptyRequest()

# 4. Create and execute the operation
operation = SynchronousOperation(
    endpoint=user_info_endpoint,
    request=request
)
user_profile = await operation.execute(client=api_client)  # Returns immediately with the result

# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:

# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
    path="/v1/images/generate",
    method=HttpMethod.POST,
    request_model=ImageGenerationRequest,
    response_model=TaskCreatedResponse,
    query_params=None
)
check_task_endpoint = ApiEndpoint(
    path="/v1/tasks/{task_id}",
    method=HttpMethod.GET,
    request_model=EmptyRequest,
    response_model=ImageGenerationResult,
    query_params=None
)

# 2. Create the request object
request = ImageGenerationRequest(
    prompt="a beautiful sunset over mountains",
    width=1024,
    height=1024,
    num_images=1
)

# 3. Start the task with a synchronous operation
task = await SynchronousOperation(
    endpoint=generate_image_endpoint,
    request=request
).execute(client=api_client)

# 4. Fill in the task ID (ApiEndpoint does not substitute placeholders itself)
#    and poll until the task finishes
check_task_endpoint.path = check_task_endpoint.path.format(task_id=task.task_id)
operation = PollingOperation(
    poll_endpoint=check_task_endpoint,
    completed_statuses=["completed"],
    failed_statuses=["failed", "error"],
    status_extractor=lambda response: response.status
)

# This polls until the status is completed (returns) or failed (raises)
result = await operation.execute(client=api_client)  # Returns the final ImageGenerationResult when done
| """ | |
| from __future__ import annotations | |
| import aiohttp | |
| import asyncio | |
| import logging | |
| import io | |
| import socket | |
| from aiohttp.client_exceptions import ClientError, ClientResponseError | |
| from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple | |
| from enum import Enum | |
| import json | |
| from urllib.parse import urljoin, urlparse | |
| from pydantic import BaseModel, Field | |
| import uuid # For generating unique operation IDs | |
| from server import PromptServer | |
| from comfy.cli_args import args | |
| from comfy import utils | |
| from . import request_logger | |
# Total of the node progress bars; progress values are reported as absolute
# positions on a 0..PROGRESS_BAR_MAX scale.
PROGRESS_BAR_MAX = 100

# Generic type parameters, each bound to a Pydantic model class.
T = TypeVar("T", bound=BaseModel)  # request model type
R = TypeVar("R", bound=BaseModel)  # response model type
P = TypeVar("P", bound=BaseModel)  # For poll response
class NetworkError(Exception):
    """Root of the network-failure hierarchy; carries a human-readable diagnosis as its message."""


class LocalNetworkError(NetworkError):
    """Signals that the failure looks like a problem with the local connection (no internet access)."""


class ApiServerError(NetworkError):
    """Signals that the API server could not be reached although the local network appears to work."""
# Marker model: SynchronousOperation sends no request body for instances of this class.
class EmptyRequest(BaseModel):
    """Base class for empty request bodies.
    For GET requests, fields will be sent as query parameters."""
# Payload describing a file to be uploaded: its name and an optional MIME type.
class UploadRequest(BaseModel):
    file_name: str = Field(..., description="Filename to upload")
    # Optional; omitted from requests serialized with exclude_none=True.
    content_type: Optional[str] = Field(
        default=None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
# Pair of URLs for an upload: one to PUT the file to, one to GET it back afterwards.
class UploadResponse(BaseModel):
    download_url: str = Field(..., description="URL to GET uploaded file")
    upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
    """HTTP verbs an ApiEndpoint can use; being a str subclass, each member equals its verb string."""

    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
class ApiClient:
    """
    Client for making HTTP requests to an API with authentication, error handling, and retry logic.

    Retries use exponential backoff: retry number ``n`` (counting from 0) waits
    ``retry_delay * retry_backoff_factor ** n`` seconds. Connection errors and
    timeouts are always retried; HTTP errors only when their status code is in
    ``retry_status_codes``. Once ``max_retries`` retries are used up the error is raised.

    Usable as an async context manager. On exit, a session created by this client
    is closed, while a caller-supplied ``session`` is left open.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        timeout: float = 3600.0,
        verify_ssl: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        retry_status_codes: Optional[Tuple[int, ...]] = None,
        session: Optional[aiohttp.ClientSession] = None,
    ):
        """
        Args:
            base_url: Base URL that request paths are joined onto (via ``urljoin``).
            auth_token: Bearer token; takes precedence over ``comfy_api_key``.
            comfy_api_key: API key sent as ``X-API-KEY`` when no auth token is set.
            timeout: Total timeout in seconds for sessions created by this client.
            verify_ssl: Value passed as ``ssl=`` to every aiohttp request.
            max_retries: Maximum number of retries for a single request.
            retry_delay: Delay in seconds before the first retry.
            retry_backoff_factor: Multiplier applied to the delay for each further retry.
            retry_status_codes: HTTP status codes that are retried. ``None`` selects the
                defaults below; an empty tuple disables retries on HTTP errors.
            session: Optional existing ``aiohttp.ClientSession`` to reuse. A
                caller-supplied session is not closed by :meth:`close`.
        """
        self.base_url = base_url
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
        # 500, 502, 503, 504 (Server Errors). Compare against None rather than
        # truthiness so an explicit empty tuple really disables status-based retries.
        self.retry_status_codes = (
            retry_status_codes if retry_status_codes is not None else (408, 429, 500, 502, 503, 504)
        )
        self._session: Optional[aiohttp.ClientSession] = session
        self._owns_session = session is None  # Track if we have to close it

    @staticmethod
    def _generate_operation_id(path: str) -> str:
        """Generates a unique operation ID for logging."""
        return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _create_json_payload_args(
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build aiohttp request kwargs that send ``data`` as a JSON body."""
        return {
            "json": data,
            "headers": headers,
        }

    def _create_form_data_args(
        self,
        data: Dict[str, Any] | None,
        files: Dict[str, Any] | None,
        headers: Optional[Dict[str, str]] = None,
        multipart_parser: Callable | None = None,
    ) -> Dict[str, Any]:
        """Build aiohttp request kwargs for a multipart/form-data body.

        Args:
            data: Text fields. ``None`` values are skipped; values other than
                bytes/bytearray are converted with ``str()``.
            files: Mapping or list of ``(field_name, file)`` pairs. Each file is either
                a raw value or a ``(filename, value[, content_type])`` tuple.
            headers: Request headers. ``Content-Type`` is deleted from this dict in
                place so the multipart content type can be set by aiohttp.
            multipart_parser: Optional callable applied to ``data`` before encoding.

        Returns:
            ``{"data": FormData, "headers": headers}``.
        """
        if headers and "Content-Type" in headers:
            del headers["Content-Type"]
        if multipart_parser and data:
            data = multipart_parser(data)

        form = aiohttp.FormData(default_to_multipart=True)
        if data:  # regular text fields
            for k, v in data.items():
                if v is None:
                    continue  # aiohttp fails to serialize "None" values
                # aiohttp expects strings or bytes; convert enums etc.
                form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)

        if files:
            file_iter = files if isinstance(files, list) else files.items()
            for field_name, file_obj in file_iter:
                if file_obj is None:
                    continue  # aiohttp fails to serialize "None" values
                # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
                if isinstance(file_obj, tuple):
                    filename, file_value, content_type = self._unpack_tuple(file_obj)
                else:
                    file_value = file_obj
                    filename = getattr(file_obj, "name", field_name)
                    content_type = "application/octet-stream"

                form.add_field(
                    name=field_name,
                    value=file_value,
                    filename=filename,
                    content_type=content_type,
                )
        return {"data": form, "headers": headers or {}}

    @staticmethod
    def _create_urlencoded_form_data_args(
        data: Dict[str, Any],
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build aiohttp request kwargs for an application/x-www-form-urlencoded body.

        A non-empty ``headers`` dict is updated in place with the Content-Type.
        """
        headers = headers or {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        return {
            "data": data,
            "headers": headers,
        }

    def get_headers(self) -> Dict[str, str]:
        """Get headers for API requests, including authentication if available.

        A bearer token wins over the API key; only one of the two is sent.
        """
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        elif self.comfy_api_key:
            headers["X-API-KEY"] = self.comfy_api_key
        return headers

    async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
        """
        Check connectivity to determine if network issues are local or server-related.

        Probes https://www.google.com first. If that fails, the issue is reported as
        local. Otherwise the ``/health`` endpoint on ``target_url``'s host is probed.
        Each probe is bounded by a 5 second total timeout.

        Args:
            target_url: URL to check connectivity to

        Returns:
            Dictionary with connectivity status details
        """
        results = {
            "internet_accessible": False,
            "api_accessible": False,
            "is_local_issue": False,
            "is_api_issue": False,
        }
        timeout = aiohttp.ClientTimeout(total=5.0)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
                    results["internet_accessible"] = resp.status < 500
            except (ClientError, asyncio.TimeoutError, socket.gaierror):
                results["is_local_issue"] = True
                return results  # cannot reach the internet – early exit

            # Now check API health endpoint
            parsed = urlparse(target_url)
            health_url = f"{parsed.scheme}://{parsed.netloc}/health"
            try:
                async with session.get(health_url, ssl=self.verify_ssl) as resp:
                    results["api_accessible"] = resp.status < 500
            except (ClientError, asyncio.TimeoutError, socket.gaierror):
                # Unreachable or timed out: leave api_accessible as False instead of
                # letting a timeout escape and mask the caller's diagnosis.
                pass

        results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
        return results

    async def request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
        headers: Optional[Dict[str, str]] = None,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        retry_count: int = 0,  # Used internally for tracking retries
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the API with automatic retries for transient errors.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters; entries whose value is None are dropped
            data: Body data, encoded according to ``content_type``
            files: Files to upload (mapping or list of ``(field_name, file)`` pairs)
            headers: Additional headers, merged over the default headers
            content_type: Content type of the request. Defaults to application/json.
                ``multipart/form-data`` and ``application/x-www-form-urlencoded``
                select form encodings; any other value sends ``data`` as JSON.
            multipart_parser: Optional callable applied to ``data`` before multipart encoding
            retry_count: Internal parameter for tracking retries, do not set manually

        Returns:
            Parsed JSON response, or ``{}`` when a successful response body is not JSON

        Raises:
            LocalNetworkError: If local network connectivity issues are detected
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For missing credentials and for non-retryable or exhausted
                HTTP errors (with a user-friendly message)
        """
        # Build full URL and merge headers
        relative_path = path.lstrip("/")
        url = urljoin(self.base_url, relative_path)
        self._check_auth(self.auth_token, self.comfy_api_key)

        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)

        # Let aiohttp generate the multipart Content-Type (with its boundary).
        if files:
            request_headers.pop("Content-Type", None)

        if params:
            params = {k: v for k, v in params.items() if v is not None}  # aiohttp fails to serialize None values

        logging.debug(f"[DEBUG] Request Headers: {request_headers}")
        logging.debug(f"[DEBUG] Files: {files}")
        logging.debug(f"[DEBUG] Params: {params}")
        logging.debug(f"[DEBUG] Data: {data}")

        if content_type == "application/x-www-form-urlencoded":
            payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
        elif content_type == "multipart/form-data":
            payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
        else:
            payload_args = self._create_json_payload_args(data, request_headers)

        operation_id = self._generate_operation_id(path)
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=method,
            request_url=url,
            request_headers=request_headers,
            request_params=params,
            request_data=data if content_type == "application/json" else "[form-data or other]",
        )

        session = await self._get_session()
        try:
            async with session.request(
                method,
                url,
                params=params,
                ssl=self.verify_ssl,
                **payload_args,
            ) as resp:
                if resp.status < 400:
                    # Success – parse JSON (safely) and log
                    try:
                        payload = await resp.json()
                        response_content_to_log = payload
                    except (aiohttp.ContentTypeError, json.JSONDecodeError):
                        payload = {}
                        response_content_to_log = await resp.text()

                    request_logger.log_request_response(
                        operation_id=operation_id,
                        request_method=method,
                        request_url=url,
                        response_status_code=resp.status,
                        response_headers=dict(resp.headers),
                        response_content=response_content_to_log,
                    )
                    return payload

                try:
                    error_data = await resp.json()
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    error_data = await resp.text()
                http_error = ClientResponseError(
                    resp.request_info,
                    resp.history,
                    status=resp.status,
                    message=error_data,
                    headers=resp.headers,
                )
        except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
            # Treat as *connection* problem – optionally retry, else escalate
            if retry_count < self.max_retries:
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
                logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
                                self.max_retries, str(e))
                await asyncio.sleep(delay)
                return await self.request(
                    method,
                    path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )
            # One final connectivity check for diagnostics
            connectivity = await self._check_connectivity(self.base_url)
            if connectivity["is_local_issue"]:
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            raise ApiServerError(
                f"The API server at {self.base_url} is currently unreachable. "
                f"The service may be experiencing issues. Please try again later."
            ) from e

        # HTTP error (status >= 400). Handled only after the response context has
        # exited, so the connection is released before any backoff sleep or retry.
        return await self._handle_http_error(
            http_error,
            operation_id,
            method,
            url,
            params,
            data,
            files,
            headers,
            content_type,
            multipart_parser,
            retry_count=retry_count,
            response_content=error_data,
        )

    @staticmethod
    def _check_auth(auth_token, comfy_api_key):
        """Verify that an auth token is present or comfy_api_key is present"""
        if auth_token is None and comfy_api_key is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token or comfy_api_key

    @staticmethod
    async def upload_file(
        upload_url: str,
        file: io.BytesIO | str,
        content_type: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> aiohttp.ClientResponse:
        """Upload a file to the API with retry logic.

        Args:
            upload_url: The URL to upload to
            file: Either a file path string or a BytesIO object
            content_type: Optional mime type to set for the upload
            max_retries: Maximum number of retry attempts
            retry_delay: Initial delay between retries in seconds
            retry_backoff_factor: Multiplier for the delay after each retry

        Returns:
            The response of the successful PUT. It is returned after its session has
            been closed, so only metadata such as ``status``/``headers`` should be used.

        Raises:
            ValueError: If ``file`` is neither a BytesIO object nor a path string.
            NetworkError: If all ``max_retries + 1`` attempts fail.
        """
        headers: Dict[str, str] = {}
        skip_auto_headers: set[str] = set()
        if content_type:
            headers["Content-Type"] = content_type
        else:
            # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
            skip_auto_headers.add("Content-Type")

        # Extract file bytes
        if isinstance(file, io.BytesIO):
            file.seek(0)
            data = file.read()
        elif isinstance(file, str):
            with open(file, "rb") as f:
                data = f.read()
        else:
            raise ValueError("File must be BytesIO or str path")

        operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method="PUT",
            request_url=upload_url,
            request_headers=headers,
            request_data=f"[File data {len(data)} bytes]",
        )

        timeout = aiohttp.ClientTimeout(total=None)  # honour server side timeouts
        delay = retry_delay
        for attempt in range(max_retries + 1):
            try:
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.put(
                        upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
                    ) as resp:
                        resp.raise_for_status()
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content="File uploaded successfully.",
                        )
                        return resp
            except (ClientError, asyncio.TimeoutError) as e:
                # Connection errors and timeouts carry no status/headers, so both are
                # read with a default instead of assuming a ClientResponseError.
                error_headers = getattr(e, "headers", None)
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    response_status_code=getattr(e, "status", None),
                    response_headers=dict(error_headers) if error_headers else None,
                    response_content=None,
                    error_message=f"{type(e).__name__}: {str(e)}",
                )
                if attempt < max_retries:
                    logging.warning(
                        "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
                    )
                    await asyncio.sleep(delay)
                    delay *= retry_backoff_factor
                else:
                    raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e

    @staticmethod
    def _format_error_message(status_code: int, response_content: Any) -> str:
        """Build the user-facing message for an HTTP error response.

        Args:
            status_code: HTTP status of the failed response.
            response_content: Parsed JSON body (usually a dict) or the raw text body.
        """
        if status_code == 401:
            return "Unauthorized: Please login first to use this node."
        if status_code == 402:
            return "Payment Required: Please add credits to your account to use this node."
        if status_code == 409:
            # NOTE(review): the captured source showed a Cloudflare-obfuscated address
            # ("[email protected]"); restored to the comfy.org support address — confirm.
            return "There is a problem with your account. Please contact support@comfy.org."
        if status_code == 429:
            return "Rate Limit Exceeded: Please try again later."

        if isinstance(response_content, dict):
            error = response_content.get("error")
            # "error" may be a plain string; only index into it when it is a dict.
            if isinstance(error, dict) and "message" in error:
                message = f"API Error: {error['message']}"
                if "type" in error:
                    message += f" (Type: {error['type']})"
                return message
            # Handle cases where error is just a JSON dict with unknown format
            return f"API Error: {json.dumps(response_content)}"

        if len(response_content) < 200:  # Arbitrary limit for display
            return f"API Error (raw): {response_content}"
        # Body too long to show; report the status code instead.
        return f"API Error (raw, status {status_code})"

    async def _handle_http_error(
        self,
        exc: ClientResponseError,
        operation_id: str,
        *req_meta,
        retry_count: int,
        response_content: dict | str = "",
    ) -> Dict[str, Any]:
        """Log an HTTP error response, then retry the request or raise.

        Args:
            exc: Error describing the failed response.
            operation_id: Logging ID of the failed request.
            *req_meta: ``(method, url, params, data, files, headers, content_type,
                multipart_parser)`` of the original request, used to replay it.
            retry_count: Number of retries already performed.
            response_content: Parsed JSON or raw text body of the error response.

        Returns:
            The result of a successful retry.

        Raises:
            Exception: With a user-friendly message when the status is not retryable
                or retries are exhausted.
        """
        status_code = exc.status
        user_friendly = self._format_error_message(status_code, response_content)

        response_headers = getattr(exc, "headers", None)
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=req_meta[0],
            request_url=req_meta[1],
            response_status_code=exc.status,
            response_headers=dict(response_headers) if response_headers else None,
            response_content=response_content,
            error_message=f"HTTP Error {exc.status}",
        )

        logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
        if response_content:
            logging.debug(f"[DEBUG] Response content: {response_content}")

        # Retry if eligible
        if status_code in self.retry_status_codes and retry_count < self.max_retries:
            delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
            logging.warning(
                "HTTP error %s. Retrying in %.2fs (%s/%s)",
                status_code,
                delay,
                retry_count + 1,
                self.max_retries,
            )
            await asyncio.sleep(delay)
            return await self.request(
                req_meta[0],  # method
                req_meta[1].replace(self.base_url, ""),  # path
                params=req_meta[2],
                data=req_meta[3],
                files=req_meta[4],
                headers=req_meta[5],
                content_type=req_meta[6],
                multipart_parser=req_meta[7],
                retry_count=retry_count + 1,
            )

        raise Exception(user_friendly) from exc

    @staticmethod
    def _unpack_tuple(t):
        """Helper to normalise (filename, file, content_type) tuples."""
        if len(t) == 3:
            return t
        elif len(t) == 2:
            return t[0], t[1], "application/octet-stream"
        else:
            raise ValueError("files tuple must be (filename, file[, content_type])")

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the open session, creating (and taking ownership of) a new one if needed."""
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
            self._owns_session = True
        return self._session

    async def close(self) -> None:
        """Close the session if this client created it; caller-supplied sessions are left open."""
        if self._owns_session and self._session and not self._session.closed:
            await self._session.close()

    async def __aenter__(self) -> "ApiClient":
        """Allow usage as async-context-manager – ensures clean teardown"""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()
class ApiEndpoint(Generic[T, R]):
    """Describes one HTTP endpoint: where it lives, how it is called, and which models it uses."""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[Dict[str, Any]] = None,
    ):
        """Store the endpoint definition.

        Args:
            path: URL path of the endpoint, relative to the client's base URL. It may
                contain placeholders such as ``{id}``; they are not substituted here.
            method: HTTP method used to call the endpoint.
            request_model: Pydantic model class describing the request payload.
            response_model: Pydantic model class used to validate the response payload.
            query_params: Query parameters sent with each request; a falsy value
                (``None`` or an empty mapping) is stored as a fresh empty dict.
        """
        self.path, self.method = path, method
        self.request_model, self.response_model = request_model, response_model
        self.query_params = query_params or {}
class SynchronousOperation(Generic[T, R]):
    """Represents a single synchronous API operation.

    :meth:`execute` serializes ``request``, sends it to ``endpoint`` through an
    :class:`ApiClient`, and validates the JSON reply with ``endpoint.response_model``.
    """

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[Dict[str, str]] = None,
        timeout: float = 604800.0,
        verify_ssl: bool = True,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> None:
        """
        Args:
            endpoint: Endpoint (path, method, models, query params) to call.
            request: Request model instance; an ``EmptyRequest`` sends no body.
            files: Files forwarded to ``ApiClient.request``.
            api_base: Base URL; defaults to ``args.comfy_api_base``.
            auth_token: Bearer token for a client created by :meth:`execute`.
            comfy_api_key: API key for a client created by :meth:`execute`.
            auth_kwargs: Optional dict whose ``auth_token`` / ``comfy_api_key`` entries
                override the explicit credential arguments.
            timeout: Client timeout in seconds (default 604800 s = 7 days).
            verify_ssl: Whether the created client verifies TLS certificates.
            content_type: Request encoding passed to ``ApiClient.request``.
            multipart_parser: Optional callable applied to the data before multipart encoding.
            max_retries: Maximum retries for the request.
            retry_delay: Delay in seconds before the first retry.
            retry_backoff_factor: Multiplier applied to the delay on each retry.
        """
        self.endpoint = endpoint
        self.request = request
        self.files = files
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.content_type = content_type
        self.multipart_parser = multipart_parser
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        """Send the request and return the validated response model.

        Args:
            client: Client to use. When omitted, a temporary ``ApiClient`` is built
                from this operation's settings and closed before returning.

        Returns:
            An ``endpoint.response_model`` instance validated from the response JSON.
        """
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                timeout=self.timeout,
                verify_ssl=self.verify_ssl,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )

        try:
            request_dict: Optional[Dict[str, Any]]
            if isinstance(self.request, EmptyRequest):
                request_dict = None
            else:
                request_dict = self.request.model_dump(exclude_none=True)
                # Top-level enum members are sent as their raw values.
                for k, v in list(request_dict.items()):
                    if isinstance(v, Enum):
                        request_dict[k] = v.value

            # Serializing bodies just for debug output can be costly, and model_dump()
            # may yield values json cannot encode natively (e.g. datetime, bytes).
            # So format only when DEBUG is enabled, with default=str as a fallback.
            debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)
            if debug_enabled:
                logging.debug(
                    f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
                )
                logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2, default=str)}")
                logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")

            response_json = await client.request(
                self.endpoint.method.value,
                self.endpoint.path,
                params=self.endpoint.query_params,
                data=request_dict,
                files=self.files,
                content_type=self.content_type,
                multipart_parser=self.multipart_parser,
            )

            if debug_enabled:
                logging.debug("=" * 50)
                logging.debug("[DEBUG] RESPONSE DETAILS:")
                logging.debug("[DEBUG] Status Code: 200 (Success)")
                logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2, default=str)}")
                logging.debug("=" * 50)

            parsed_response = self.endpoint.response_model.model_validate(response_json)
            logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
            return parsed_response
        finally:
            if owns_client:
                await client.close()
class TaskStatus(str, Enum):
    """Normalized polling outcome; each member compares equal to its lowercase string value."""

    COMPLETED = "completed"
    FAILED = "failed"
    PENDING = "pending"
class PollingOperation(Generic[T, R]):
    """Represents an asynchronous API operation that requires polling for completion.

    :meth:`execute` repeatedly requests ``poll_endpoint``. It stops when the status
    returned by ``status_extractor`` is in ``completed_statuses`` (the validated
    response is returned) or in ``failed_statuses`` (an exception is raised). It also
    stops, by raising, once ``max_poll_attempts`` polls have been made.
    """

    def __init__(
        self,
        poll_endpoint: ApiEndpoint[EmptyRequest, R],
        completed_statuses: list[str],
        failed_statuses: list[str],
        status_extractor: Callable[[R], str],
        progress_extractor: Callable[[R], float] | None = None,
        result_url_extractor: Callable[[R], str] | None = None,
        request: Optional[T] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[Dict[str, str]] = None,
        poll_interval: float = 5.0,
        max_poll_attempts: int = 120,  # Default max polling attempts (10 minutes with 5s interval)
        max_retries: int = 3,  # Max retries per individual API call
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        estimated_duration: Optional[float] = None,
        node_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            poll_endpoint: Endpoint requested on every poll.
            completed_statuses: Status values meaning the task succeeded.
            failed_statuses: Status values meaning the task failed.
            status_extractor: Returns the status from a parsed poll response. If a falsy
                value is passed, the response's ``status`` attribute is used instead.
            progress_extractor: Optional; returns the task progress, shown on a progress
                bar whose total is ``PROGRESS_BAR_MAX`` (``None`` leaves the bar unchanged).
            result_url_extractor: Optional; returns a result URL shown on the node on completion.
            request: Optional model sent as the body of every poll request.
            api_base: Base URL; defaults to ``args.comfy_api_base``.
            auth_token: Bearer token for a client created by :meth:`execute`.
            comfy_api_key: API key for a client created by :meth:`execute`.
            auth_kwargs: Optional dict whose ``auth_token`` / ``comfy_api_key`` entries
                override the explicit credential arguments.
            poll_interval: Seconds to wait between polls (fractions are honoured).
            max_poll_attempts: Maximum number of polls before timing out.
            max_retries: Retries per individual API call. Also bounds the number of
                consecutive failed polls tolerated (``min(5, 2 * max_retries)``).
            retry_delay: Initial retry delay in seconds for the client.
            retry_backoff_factor: Backoff multiplier for the client's retries.
            estimated_duration: Optional expected duration in seconds, used only to show
                an estimate of the remaining time on the node.
            node_id: Node that receives progress text; nothing is displayed when unset.
        """
        self.poll_endpoint = poll_endpoint
        self.request = request
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.poll_interval = poll_interval
        self.max_poll_attempts = max_poll_attempts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        self.estimated_duration = estimated_duration
        self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
        self.progress_extractor = progress_extractor
        self.result_url_extractor = result_url_extractor
        self.node_id = node_id
        self.completed_statuses = completed_statuses
        self.failed_statuses = failed_statuses
        self.final_response: Optional[R] = None

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        """Poll until completion and return the final validated response.

        Args:
            client: Client to use. When omitted, a temporary ``ApiClient`` is created
                from this operation's settings and closed before returning.
        """
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )
        try:
            return await self._poll_until_complete(client)
        finally:
            if owns_client:
                await client.close()

    def _display_text_on_node(self, text: str):
        """Send ``text`` as progress text to the node; no-op when ``node_id`` is unset."""
        if not self.node_id:
            return
        PromptServer.instance.send_progress_text(text, self.node_id)

    def _display_time_progress_on_node(self, time_completed: int | float):
        """Show elapsed seconds, plus an estimate of the remaining time when available."""
        if not self.node_id:
            return
        if self.estimated_duration is not None:
            remaining = max(0, int(self.estimated_duration) - time_completed)
            message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
        else:
            message = f"Task in progress: {time_completed}s"
        self._display_text_on_node(message)

    def _check_task_status(self, response: R) -> TaskStatus:
        """Map the extracted status onto TaskStatus; extractor errors count as PENDING."""
        try:
            status = self.status_extractor(response)
            if status in self.completed_statuses:
                return TaskStatus.COMPLETED
            if status in self.failed_statuses:
                return TaskStatus.FAILED
            return TaskStatus.PENDING
        except Exception as e:
            logging.error("Error extracting status: %s", e)
            return TaskStatus.PENDING

    async def _wait_for_next_poll(self, poll_count: int) -> None:
        """Sleep ``poll_interval`` seconds, refreshing the elapsed-time text once per second."""
        whole_seconds = int(self.poll_interval)
        elapsed_before = (poll_count - 1) * self.poll_interval
        for i in range(whole_seconds):
            self._display_time_progress_on_node(elapsed_before + i)
            await asyncio.sleep(1)
        # int() truncates: sleep the fractional remainder as well. Otherwise an
        # interval below one second would not wait at all and busy-poll the server.
        remainder = self.poll_interval - whole_seconds
        if remainder > 0:
            await asyncio.sleep(remainder)

    async def _poll_until_complete(self, client: ApiClient) -> R:
        """Poll until the task is complete.

        Returns:
            The validated poll response whose status is in ``completed_statuses``.

        Raises:
            Exception: If the task reports a failed status, too many consecutive
                polls fail, or ``max_poll_attempts`` is exhausted.
        """
        consecutive_errors = 0
        max_consecutive_errors = min(5, self.max_retries * 2)  # Limit consecutive errors
        progress = utils.ProgressBar(PROGRESS_BAR_MAX) if self.progress_extractor else None
        status = TaskStatus.PENDING
        # The poll body is identical for every attempt, so serialize it only once.
        request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)

        for poll_count in range(1, self.max_poll_attempts + 1):
            try:
                logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
                if poll_count == 1 and logging.getLogger().isEnabledFor(logging.DEBUG):
                    logging.debug(
                        f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
                    )
                    # default=str: model_dump() may contain values json cannot encode natively.
                    logging.debug(
                        f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2, default=str) if request_dict else 'None'}"
                    )

                # Query task status
                resp = await client.request(
                    self.poll_endpoint.method.value,
                    self.poll_endpoint.path,
                    params=self.poll_endpoint.query_params,
                    data=request_dict,
                )
                consecutive_errors = 0  # reset on success
                response_obj: R = self.poll_endpoint.response_model.model_validate(resp)

                # Check if task is complete
                status = self._check_task_status(response_obj)
                logging.debug(f"[DEBUG] Task Status: {status}")

                # If progress extractor is provided, extract progress
                if progress is not None:
                    new_progress = self.progress_extractor(response_obj)
                    if new_progress is not None:
                        progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

                if status == TaskStatus.COMPLETED:
                    message = "Task completed successfully"
                    if self.result_url_extractor:
                        result_url = self.result_url_extractor(response_obj)
                        if result_url:
                            message = f"Result URL: {result_url}"
                    logging.debug(f"[DEBUG] {message}")
                    self._display_text_on_node(message)
                    self.final_response = response_obj
                    if progress is not None:
                        progress.update_absolute(PROGRESS_BAR_MAX, total=PROGRESS_BAR_MAX)
                    return self.final_response

                if status == TaskStatus.FAILED:
                    message = f"Task failed: {json.dumps(resp)}"
                    logging.error(f"[DEBUG] {message}")
                    # Caught below and re-raised immediately because status is FAILED.
                    raise Exception(message)

                logging.debug("[DEBUG] Task still pending, continuing to poll...")
                # Task pending – wait
                await self._wait_for_next_poll(poll_count)

            except (LocalNetworkError, ApiServerError, NetworkError) as e:
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
                    ) from e
                logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
                await asyncio.sleep(self.poll_interval)
            except Exception as e:
                # For other errors, increment count and potentially abort
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
                    ) from e

                logging.error(f"[DEBUG] Polling error: {str(e)}")
                logging.warning(
                    f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
                    f"Will retry in {self.poll_interval} seconds."
                )
                await asyncio.sleep(self.poll_interval)

        # If we've exhausted all polling attempts
        raise Exception(
            f"Polling timed out after {self.max_poll_attempts} attempts ({self.max_poll_attempts * self.poll_interval} seconds). "
            "The operation may still be running on the server but is taking longer than expected."
        )