Spaces:
Running
Running
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import replicate | |
| from dotenv import load_dotenv | |
| from api.flux import FluxAPI | |
class PrunaAPI(FluxAPI):
    """Benchmark client for the ``prunaai/flux.1-juiced`` model on Replicate.

    Each call to :meth:`generate_image` runs one prediction with fixed
    generation settings (seed 0, 28 inference steps, guidance 3.5, one 1:1
    PNG output) and the ``speed_mode`` chosen at construction time.
    """

    # Pinned model version so that benchmark runs are reproducible.
    _MODEL = (
        "prunaai/flux.1-juiced:"
        "58977759ff2870cc010597ae75f4d87866d169b248e02b6e86c4e1bf8afe2410"
    )

    def __init__(self, speed_mode: str):
        """Configure the client for one speed mode.

        Args:
            speed_mode: Forwarded verbatim as the model's ``speed_mode`` input.
                Its first whitespace-separated word, lower-cased, is used to
                build :meth:`name`.

        Raises:
            ValueError: If ``REPLICATE_API_TOKEN`` is not set in the
                environment after loading any ``.env`` file.
        """
        self._speed_mode = speed_mode
        # Only the first word identifies the mode (e.g. a hypothetical
        # "Extra Juiced" -> "extra"). split() with no argument also skips
        # leading whitespace, which previously produced an empty name; a blank
        # mode still yields "" instead of raising IndexError.
        words = speed_mode.split()
        self._speed_mode_name = words[0].lower() if words else ""
        load_dotenv()
        # NOTE(review): the token is only validated here and never passed
        # explicitly; replicate.run() presumably reads it from the environment.
        self._api_key = os.getenv("REPLICATE_API_TOKEN")
        if not self._api_key:
            raise ValueError("REPLICATE_API_TOKEN not found in environment variables")

    def name(self) -> str:
        """Return the identifier of this configuration, e.g. ``pruna_juiced``."""
        return f"pruna_{self._speed_mode_name}"

    def generate_image(self, prompt: str, save_path: Path) -> float:
        """Generate one image for *prompt* and write it to *save_path*.

        Args:
            prompt: Text prompt sent to the model.
            save_path: Destination file; parent directories are created.

        Returns:
            Elapsed seconds spent in the ``replicate.run`` call only. Saving or
            downloading the output afterwards is not included.

        Raises:
            RuntimeError: If Replicate returns an empty/falsy result.
        """
        # perf_counter is monotonic, so the measurement is immune to
        # system-clock adjustments (unlike time.time()).
        start_time = time.perf_counter()
        result = replicate.run(
            self._MODEL,
            input={
                "seed": 0,
                "prompt": prompt,
                "guidance": 3.5,
                "num_outputs": 1,
                "aspect_ratio": "1:1",
                "output_format": "png",
                "speed_mode": self._speed_mode,
                "num_inference_steps": 28,
            },
        )
        elapsed = time.perf_counter() - start_time

        if not result:
            raise RuntimeError("No result returned from Replicate API")
        self._save_image_from_result(result, save_path)
        return elapsed

    def _save_image_from_result(self, result: Any, save_path: Path) -> None:
        """Write the bytes of a Replicate output object to *save_path*.

        *result* must expose ``read()`` returning bytes, or be a non-empty
        list/tuple whose first element does.
        """
        # NOTE(review): depending on the model schema / client version,
        # replicate.run may return a list of outputs; take the first one.
        if isinstance(result, (list, tuple)):
            result = result[0]
        # Read before touching the filesystem so a failed download does not
        # leave an empty or truncated file at save_path.
        data = result.read()
        save_path.parent.mkdir(parents=True, exist_ok=True)
        save_path.write_bytes(data)