"""FastAPI service: refines user requests into image prompts via an LLM (Groq),
generates images through the BFL Flux API, serves them statically, and offers
4x AuraSR upscaling."""

import os

# The Hugging Face hub resolves its cache directory when it is first imported,
# so HF_HOME must be exported *before* `aura_sr` (which imports huggingface_hub).
# /tmp is writable even in read-only container deployments.
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

import asyncio
import logging
import tempfile
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from urllib.parse import quote_plus

import requests
from aura_sr import AuraSR
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from langchain_groq import ChatGroq
from PIL import Image
from pydantic import BaseModel
from pymongo import MongoClient

# Load environment variables from a local .env file, if present.
load_dotenv()


def _require_env(*names):
    """Raise if any of the given environment variables is unset or empty.

    `assert` is deliberately not used here: assertions are stripped when the
    interpreter runs with -O, which would silently disable this validation.
    """
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}")


_require_env('MONGO_USER', 'MONGO_PASSWORD', 'MONGO_HOST')  # MongoDB credentials
_require_env('LLM_API_KEY')   # Groq LLM API key
_require_env('BFL_API_KEY')   # BFL image-generation API key

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Wide-open CORS: the API is meant to be called from arbitrary frontends.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
executor = ThreadPoolExecutor(max_workers=5)  # caps concurrent upscaling jobs
llm = None      # ChatGroq client, created during startup
aura_sr = None  # AuraSR upscaling model, created during startup

mongo_client = MongoClient(
    f"mongodb+srv://{os.getenv('MONGO_USER')}:{quote_plus(os.getenv('MONGO_PASSWORD'))}"
    f"@{os.getenv('MONGO_HOST')}/"
)
db = mongo_client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}

# Temporary directory for generated images, served read-only under /images.
image_storage_dir = tempfile.mkdtemp()
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")


@app.on_event("startup")
async def startup():
    """Initialise the LLM client and the AuraSR upscaling model."""
    global llm, aura_sr
    try:
        llm = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0.7,
            max_tokens=1024,
            api_key=os.getenv('LLM_API_KEY'),
        )
        aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
    except Exception as e:
        # Keep the app serving; endpoints that need the models check for None
        # and return an explicit error instead of crashing with AttributeError.
        logger.error(f"Error initializing models: {e}")


@app.on_event("shutdown")
def shutdown():
    """Release pooled resources when the application stops."""
    mongo_client.close()
    executor.shutdown()


# Pydantic models
class ImageRequest(BaseModel):
    """Attributes describing the image to generate, plus the chat session id."""
    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str
    chat_id: str


class UpscaleRequest(BaseModel):
    """URL of an image to upscale 4x with AuraSR."""
    image_url: str


# Helper functions
def generate_chat_id():
    """Create a fresh chat session id and register it in `chat_sessions`."""
    chat_id = str(uuid.uuid4())
    chat_sessions[chat_id] = collection
    return chat_id


def get_chat_history(chat_id):
    """Return the session's stored messages as a "User:/AI:" transcript string."""
    messages = collection.find({"session_id": chat_id})
    return "\n".join(
        f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
        for msg in messages
    )


def save_image_locally(image, filename):
    """Save a PIL image as PNG inside the served image directory; return its path."""
    filepath = os.path.join(image_storage_dir, filename)
    image.save(filepath, format="PNG")
    return filepath


def make_request_with_retries(url, headers, payload, retries=3, delay=2):
    """
    Makes an HTTP POST request with retries in case of failure.

    :param url: The URL for the request.
    :param headers: Headers to include in the request.
    :param payload: JSON payload to include in the request.
    :param retries: Number of attempts before giving up.
    :param delay: Seconds to sleep between attempts.
    :return: Response JSON from the server.
    :raises HTTPException: 500 once every attempt has failed.
    """
    for attempt in range(retries):
        try:
            with requests.Session() as session:
                response = session.post(url, headers=headers, json=payload, timeout=30)
                response.raise_for_status()
                return response.json()
        except requests.exceptions.RequestException as e:
            if attempt < retries - 1:
                time.sleep(delay)
            else:
                raise HTTPException(
                    status_code=500,
                    detail=f"Request failed after {retries} attempts: {str(e)}",
                )


def fetch_image(url):
    """Download `url` and return it as a PIL Image (HTTP 400 on any failure)."""
    try:
        with requests.Session() as session:
            response = session.get(url, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")


def poll_for_image_result(request_id, headers):
    """Poll the BFL API until the generation job finishes; return the image URL.

    :raises HTTPException: 500 on a reported generation error or after 60 s.
    """
    timeout = 60
    start_time = time.time()
    while time.time() - start_time < timeout:
        time.sleep(0.5)
        with requests.Session() as session:
            response = session.get(
                "https://api.bfl.ml/v1/get_result",
                headers=headers,
                params={"id": request_id},
                timeout=10,
            )
            # Surface transport/HTTP errors explicitly instead of failing on
            # an unexpected JSON body below.
            response.raise_for_status()
            result = response.json()
        if result["status"] == "Ready":
            return result["result"].get("sample")
        elif result["status"] == "Error":
            raise HTTPException(
                status_code=500,
                detail=f"Image generation failed: {result.get('error', 'Unknown error')}",
            )
    raise HTTPException(status_code=500, detail="Image generation timed out.")


@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Create and return a new chat session id."""
    chat_id = generate_chat_id()
    return {"chat_id": chat_id}


@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest):
    """Refine the request into a prompt via the LLM, generate the image via the
    BFL Flux API, store it locally, and return its path and served URL."""
    if llm is None:
        # Startup swallowed an init error; fail cleanly rather than with
        # an AttributeError (mirrors the aura_sr guard in /upscale-image).
        raise HTTPException(status_code=500, detail="LLM not initialized.")

    # NOTE(review): chat_history and request.color_mode are currently not used
    # in the prompt below — confirm whether that is intentional.
    chat_history = get_chat_history(request.chat_id)
    prompt = f"""
    You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.

    Image Specifications:
    - **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
    - **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
    - **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
    - **Camera and Lighting**:
    - Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
    - **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
    - **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
    - **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
    - **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
    - Negative Prompt: Avoid grainy, blurry, or deformed outputs.
    - **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
    """
    refined_prompt = llm.invoke(prompt).content.strip()

    # Persist both sides of the exchange for this session.
    collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
    collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

    headers = {
        "accept": "application/json",
        "x-key": os.getenv('BFL_API_KEY'),
        "Content-Type": "application/json",
    }
    payload = {
        "prompt": refined_prompt,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 1,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    }
    response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
    if "id" not in response:
        raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")

    image_url = poll_for_image_result(response["id"], headers)
    image = fetch_image(image_url)
    filename = f"generated_{uuid.uuid4()}.png"
    filepath = save_image_locally(image, filename)
    return {
        "status": "Image generated successfully",
        "file_path": filepath,
        # BUGFIX: was a broken literal URL; build it from the saved file name,
        # matching the /upscale-image endpoint.
        "file_url": f"/images/{filename}",
    }


@app.post("/upscale-image", response_model=dict)
async def upscale_image(request: UpscaleRequest):
    """Fetch an image from a URL and upscale it 4x with AuraSR."""
    if aura_sr is None:
        raise HTTPException(status_code=500, detail="Upscaling model not initialized.")

    img = await run_in_threadpool(fetch_image, request.image_url)

    def perform_upscaling():
        # Heavy compute; concurrency is bounded by the dedicated 5-worker executor.
        upscaled_image = aura_sr.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        return save_image_locally(upscaled_image, filename)

    # Await the executor future directly instead of parking a second
    # threadpool thread on future.result().
    filepath = await asyncio.wrap_future(executor.submit(perform_upscaling))
    return {
        "status": "Upscaling successful",
        "file_path": filepath,
        "file_url": f"/images/{os.path.basename(filepath)}",
    }


@app.get("/")
async def root():
    """Liveness probe."""
    return {"message": "API is up and running!"}