import gradio as gr
import numpy as np
import random
import os
import base64
import requests
import time
import io
from PIL import Image, ImageOps
import pillow_heif  # For HEIF/AVIF support

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max
API_URL = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev?_subdomain=queue"


def get_headers():
    """Get headers for API requests."""
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error(
            "HF_TOKEN environment variable not found. "
            "Please add your Hugging Face token to the Space settings."
        )
    return {
        "Authorization": f"Bearer {hf_token}",
        "X-HF-Bill-To": "huggingface"
    }


def query_api(payload, progress_callback=None):
    """Submit a job to the API, poll until it completes, and return the image bytes."""
    headers = get_headers()

    # Submit the job
    response = requests.post(API_URL, headers=headers, json=payload)

    if response.status_code != 200:
        raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")

    # Parse the initial response
    try:
        json_response = response.json()
        print(f"Initial response: {json_response}")
    except requests.exceptions.JSONDecodeError:
        raise gr.Error("Failed to parse initial API response as JSON")

    # Check if the job was queued
    if json_response.get("status") == "IN_QUEUE":
        status_url = json_response.get("status_url")
        if not status_url:
            raise gr.Error("No status URL provided in queue response")

        # Poll for completion
        max_attempts = 60  # Wait up to 5 minutes (60 * 5 seconds)
        attempt = 0

        while attempt < max_attempts:
            if progress_callback:
                progress_callback(
                    0.1 + (attempt / max_attempts) * 0.8,
                    f"Processing... (attempt {attempt + 1}/{max_attempts})"
                )

            time.sleep(5)  # Wait 5 seconds between polls

            # Check status
            status_response = requests.get(status_url, headers=headers)
            if status_response.status_code != 200:
                raise gr.Error(f"Status check failed: {status_response.status_code}")

            try:
                status_data = status_response.json()
                print(f"Status check {attempt + 1}: {status_data}")

                if status_data.get("status") == "COMPLETED":
                    # Job completed, get the result
                    response_url = json_response.get("response_url")
                    if not response_url:
                        raise gr.Error("No response URL provided")

                    result_response = requests.get(response_url, headers=headers)
                    if result_response.status_code != 200:
                        raise gr.Error(f"Failed to get result: {result_response.status_code}")

                    # Check if the result is JSON with image data
                    try:
                        result_data = result_response.json()
                        print(f"Result data: {result_data}")

                        # Look for the image in the various possible fields
                        if 'images' in result_data and len(result_data['images']) > 0:
                            # Images array with URLs or base64 strings
                            image_data = result_data['images'][0]
                            if isinstance(image_data, dict) and 'url' in image_data:
                                # Image URL - fetch it
                                img_response = requests.get(image_data['url'])
                                return img_response.content
                            elif isinstance(image_data, str):
                                # Assume base64
                                return base64.b64decode(image_data)
                        elif 'image' in result_data:
                            # Single image field
                            if isinstance(result_data['image'], str):
                                return base64.b64decode(result_data['image'])
                        elif 'url' in result_data:
                            # Direct URL
                            img_response = requests.get(result_data['url'])
                            return img_response.content
                        else:
                            raise gr.Error(f"No image found in result: {result_data}")
                    except requests.exceptions.JSONDecodeError:
                        # Result might be raw image bytes
                        return result_response.content

                elif status_data.get("status") == "FAILED":
                    error_msg = status_data.get("error", "Unknown error")
                    raise gr.Error(f"Job failed: {error_msg}")

                # Still processing, continue polling
                attempt += 1

            except requests.exceptions.JSONDecodeError:
                raise gr.Error("Failed to parse status response")
after 5 minutes") elif json_response.get("status") == "COMPLETED": # Job completed immediately if 'images' in json_response and len(json_response['images']) > 0: image_data = json_response['images'][0] if isinstance(image_data, str): return base64.b64decode(image_data) elif 'image' in json_response: return base64.b64decode(json_response['image']) else: raise gr.Error(f"No image found in immediate response: {json_response}") else: raise gr.Error(f"Unexpected response status: {json_response.get('status', 'unknown')}") # --- Core Inference Function for ChatInterface --- def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()): """ Performs image generation or editing based on user input from the chat interface. """ # Register HEIF opener with PIL for AVIF/HEIF support pillow_heif.register_heif_opener() prompt = message["text"] files = message["files"] if not prompt and not files: raise gr.Error("Please provide a prompt and/or upload an image.") if randomize_seed: seed = random.randint(0, MAX_SEED) # Prepare the payload payload = { "parameters": { "prompt": prompt, "seed": seed, "guidance_scale": guidance_scale, "num_inference_steps": steps } } if files: print(f"Received image: {files[0]}") try: # Try to open and convert the image input_image = Image.open(files[0]) # Convert to RGB if needed (handles RGBA, P, etc.) if input_image.mode != "RGB": input_image = input_image.convert("RGB") # Auto-orient the image based on EXIF data input_image = ImageOps.exif_transpose(input_image) # Convert PIL image to base64 for the API img_byte_arr = io.BytesIO() input_image.save(img_byte_arr, format='PNG') img_byte_arr.seek(0) image_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') # Add image to payload for image-to-image payload["inputs"] = image_base64 except Exception as e: raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).") progress(0.1, desc="Processing image...") else: print(f"Received prompt for text-to-image: {prompt}") # For text-to-image, we don't need the inputs field progress(0.1, desc="Generating image...") try: # Make API request image_bytes = query_api(payload) # Try to convert response bytes to PIL Image with better error handling try: image = Image.open(io.BytesIO(image_bytes)) except Exception as img_error: print(f"Failed to open image directly: {img_error}") # Maybe it's a different format, try to save and examine with open('/tmp/debug_response.bin', 'wb') as f: f.write(image_bytes) print(f"Saved response to /tmp/debug_response.bin for debugging") # Try to decode as base64 if direct opening failed try: decoded_bytes = base64.b64decode(image_bytes) image = Image.open(io.BytesIO(decoded_bytes)) except: raise gr.Error(f"Could not process API response as image. 
                raise gr.Error(
                    "Could not process API response as image. "
                    f"Response type: {type(image_bytes)}, "
                    f"Length: {len(image_bytes) if isinstance(image_bytes, (bytes, str)) else 'unknown'}"
                )

        progress(1.0, desc="Complete!")
        return gr.Image(value=image)

    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}")


# --- UI Definition using gr.ChatInterface ---
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev] - Direct API",
    description="""
A simple chat UI for the FLUX.1 Kontext model using direct API calls with requests.
To edit an image, upload it and type your instructions (e.g., "Add a hat").
To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
Find the model on Hugging Face.