Spaces:
Runtime error
Runtime error
import io
import os
import time
from pathlib import Path
from urllib.parse import urlsplit

import requests
from PIL import Image
# Base URL of the image generation HTTP API; request paths are appended to it.
API_URL = "https://api.bfl.ml"

# Maps the user-facing model name to the path segment that is requested under
# ``{API_URL}/v1/`` when a generation is submitted.
API_ENDPOINTS = {
    "flux.1-pro": "flux-pro",
    "flux.1-dev": "flux-dev",
    "flux.1.1-pro": "flux-pro-1.1",
}
| class ApiException(Exception): | |
| def __init__(self, status_code: int, detail: str | list[dict] | None = None): | |
| super().__init__() | |
| self.detail = detail | |
| self.status_code = status_code | |
| def __str__(self) -> str: | |
| return self.__repr__() | |
| def __repr__(self) -> str: | |
| if self.detail is None: | |
| message = None | |
| elif isinstance(self.detail, str): | |
| message = self.detail | |
| else: | |
| message = "[" + ",".join(d["msg"] for d in self.detail) + "]" | |
| return f"ApiException({self.status_code=}, {message=}, detail={self.detail})" | |
| class ImageRequest: | |
| def __init__( | |
| self, | |
| # api inputs | |
| prompt: str, | |
| name: str = "flux.1.1-pro", | |
| width: int | None = None, | |
| height: int | None = None, | |
| num_steps: int | None = None, | |
| prompt_upsampling: bool | None = None, | |
| seed: int | None = None, | |
| guidance: float | None = None, | |
| interval: float | None = None, | |
| safety_tolerance: int | None = None, | |
| # behavior of this class | |
| validate: bool = True, | |
| launch: bool = True, | |
| api_key: str | None = None, | |
| ): | |
| """ | |
| Manages an image generation request to the API. | |
| All parameters not specified will use the API defaults. | |
| Args: | |
| prompt: Text prompt for image generation. | |
| width: Width of the generated image in pixels. Must be a multiple of 32. | |
| height: Height of the generated image in pixels. Must be a multiple of 32. | |
| name: Which model version to use | |
| num_steps: Number of steps for the image generation process. | |
| prompt_upsampling: Whether to perform upsampling on the prompt. | |
| seed: Optional seed for reproducibility. | |
| guidance: Guidance scale for image generation. | |
| safety_tolerance: Tolerance level for input and output moderation. | |
| Between 0 and 6, 0 being most strict, 6 being least strict. | |
| validate: Run input validation | |
| launch: Directly launches request | |
| api_key: Your API key if not provided by the environment | |
| Raises: | |
| ValueError: For invalid input, when `validate` | |
| ApiException: For errors raised from the API | |
| """ | |
| if validate: | |
| if name not in API_ENDPOINTS.keys(): | |
| raise ValueError(f"Invalid model {name}") | |
| elif width is not None and width % 32 != 0: | |
| raise ValueError(f"width must be divisible by 32, got {width}") | |
| elif width is not None and not (256 <= width <= 1440): | |
| raise ValueError(f"width must be between 256 and 1440, got {width}") | |
| elif height is not None and height % 32 != 0: | |
| raise ValueError(f"height must be divisible by 32, got {height}") | |
| elif height is not None and not (256 <= height <= 1440): | |
| raise ValueError(f"height must be between 256 and 1440, got {height}") | |
| elif num_steps is not None and not (1 <= num_steps <= 50): | |
| raise ValueError(f"steps must be between 1 and 50, got {num_steps}") | |
| elif guidance is not None and not (1.5 <= guidance <= 5.0): | |
| raise ValueError(f"guidance must be between 1.5 and 4, got {guidance}") | |
| elif interval is not None and not (1.0 <= interval <= 4.0): | |
| raise ValueError(f"interval must be between 1 and 4, got {interval}") | |
| elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0): | |
| raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}") | |
| if name == "flux.1-dev": | |
| if interval is not None: | |
| raise ValueError("Interval is not supported for flux.1-dev") | |
| if name == "flux.1.1-pro": | |
| if interval is not None or num_steps is not None or guidance is not None: | |
| raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro") | |
| self.name = name | |
| self.request_json = { | |
| "prompt": prompt, | |
| "width": width, | |
| "height": height, | |
| "steps": num_steps, | |
| "prompt_upsampling": prompt_upsampling, | |
| "seed": seed, | |
| "guidance": guidance, | |
| "interval": interval, | |
| "safety_tolerance": safety_tolerance, | |
| } | |
| self.request_json = {key: value for key, value in self.request_json.items() if value is not None} | |
| self.request_id: str | None = None | |
| self.result: dict | None = None | |
| self._image_bytes: bytes | None = None | |
| self._url: str | None = None | |
| if api_key is None: | |
| self.api_key = os.environ.get("BFL_API_KEY") | |
| else: | |
| self.api_key = api_key | |
| if launch: | |
| self.request() | |
| def request(self): | |
| """ | |
| Request to generate the image. | |
| """ | |
| if self.request_id is not None: | |
| return | |
| response = requests.post( | |
| f"{API_URL}/v1/{API_ENDPOINTS[self.name]}", | |
| headers={ | |
| "accept": "application/json", | |
| "x-key": self.api_key, | |
| "Content-Type": "application/json", | |
| }, | |
| json=self.request_json, | |
| ) | |
| result = response.json() | |
| if response.status_code != 200: | |
| raise ApiException(status_code=response.status_code, detail=result.get("detail")) | |
| self.request_id = response.json()["id"] | |
| def retrieve(self) -> dict: | |
| """ | |
| Wait for the generation to finish and retrieve response. | |
| """ | |
| if self.request_id is None: | |
| self.request() | |
| while self.result is None: | |
| response = requests.get( | |
| f"{API_URL}/v1/get_result", | |
| headers={ | |
| "accept": "application/json", | |
| "x-key": self.api_key, | |
| }, | |
| params={ | |
| "id": self.request_id, | |
| }, | |
| ) | |
| result = response.json() | |
| if "status" not in result: | |
| raise ApiException(status_code=response.status_code, detail=result.get("detail")) | |
| elif result["status"] == "Ready": | |
| self.result = result["result"] | |
| elif result["status"] == "Pending": | |
| time.sleep(0.5) | |
| else: | |
| raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'") | |
| return self.result | |
| def bytes(self) -> bytes: | |
| """ | |
| Generated image as bytes. | |
| """ | |
| if self._image_bytes is None: | |
| response = requests.get(self.url) | |
| if response.status_code == 200: | |
| self._image_bytes = response.content | |
| else: | |
| raise ApiException(status_code=response.status_code) | |
| return self._image_bytes | |
| def url(self) -> str: | |
| """ | |
| Public url to retrieve the image from | |
| """ | |
| if self._url is None: | |
| result = self.retrieve() | |
| self._url = result["sample"] | |
| return self._url | |
| def image(self) -> Image.Image: | |
| """ | |
| Load the image as a PIL Image | |
| """ | |
| return Image.open(io.BytesIO(self.bytes)) | |
| def save(self, path: str): | |
| """ | |
| Save the generated image to a local path | |
| """ | |
| suffix = Path(self.url).suffix | |
| if not path.endswith(suffix): | |
| path = path + suffix | |
| Path(path).resolve().parent.mkdir(parents=True, exist_ok=True) | |
| with open(path, "wb") as file: | |
| file.write(self.bytes) | |
if __name__ == "__main__":
    # Imported here so that `fire` is only required when the module is run
    # as a script; it is handed ImageRequest to drive it from the command line.
    from fire import Fire

    Fire(ImageRequest)