import gradio as gr
import numpy as np
import random
import os
import base64
import requests
import time
import io
from PIL import Image, ImageOps
import pillow_heif  # For HEIF/AVIF support

# Register the HEIF/AVIF opener with PIL once at import time (the call is
# idempotent, so there is no need to repeat it on every request).
pillow_heif.register_heif_opener()

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max
API_URL = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev?_subdomain=queue"
STATUS_URL_TEMPLATE = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev/requests/{request_id}/status"
RESULT_URL_TEMPLATE = "https://router.huggingface.co/fal-ai/fal-ai/flux-kontext/dev/requests/{request_id}"
REQUEST_TIMEOUT = 60     # seconds allowed for each individual HTTP request
POLL_INTERVAL = 5        # seconds to wait between queue-status polls
MAX_POLL_ATTEMPTS = 60   # 60 polls * 5 s = wait up to 5 minutes total


def get_headers():
    """Build auth headers for Hugging Face router API requests.

    Returns:
        dict: ``Authorization`` bearer token plus the HF billing header.

    Raises:
        gr.Error: If the ``HF_TOKEN`` environment variable is not set.
    """
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
    return {
        "Authorization": f"Bearer {hf_token}",
        "X-HF-Bill-To": "huggingface"
    }


def _image_bytes_from_response(response):
    """Extract raw image bytes from a completed-job HTTP response.

    Handles the result shapes observed from the router:
    direct ``image/*`` bytes, ``{"images": [{"url": ...}]}`` (download),
    ``{"images": ["<base64>"]}`` and ``{"image": "<base64>"}`` (decode).
    Falls back to returning the raw body when the shape is unrecognized.
    """
    content_type = response.headers.get('content-type', '').lower()
    if 'image/' in content_type:
        # Direct image bytes.
        return response.content
    try:
        result_data = response.json()
    except requests.exceptions.JSONDecodeError:
        # Not JSON — assume the body is already image bytes.
        return response.content
    print(f"Result data: {result_data}")
    if result_data.get('images'):
        image_info = result_data['images'][0]
        if isinstance(image_info, dict) and 'url' in image_info:
            # Image delivered by URL — download it.
            img_response = requests.get(image_info['url'], timeout=REQUEST_TIMEOUT)
            return img_response.content
        if isinstance(image_info, str):
            # Base64-encoded image string.
            return base64.b64decode(image_info)
    elif isinstance(result_data.get('image'), str):
        # Single base64 image field.
        return base64.b64decode(result_data['image'])
    # Unknown JSON shape: pass the raw body through and let the caller try.
    return response.content


def _poll_until_complete(request_id, headers, progress_callback=None):
    """Poll the router queue until the job finishes; return image bytes.

    Transient status-check failures (non-200 or unparseable JSON) are
    tolerated and polling continues.

    Raises:
        gr.Error: If the job reports FAILED, the result fetch fails, or
            the job does not complete within MAX_POLL_ATTEMPTS polls.
    """
    status_url = STATUS_URL_TEMPLATE.format(request_id=request_id)
    for attempt in range(MAX_POLL_ATTEMPTS):
        if progress_callback:
            progress_callback(0.1 + (attempt / MAX_POLL_ATTEMPTS) * 0.8,
                              f"Processing... (attempt {attempt + 1}/{MAX_POLL_ATTEMPTS})")
        time.sleep(POLL_INTERVAL)

        status_response = requests.get(status_url, headers=headers, timeout=REQUEST_TIMEOUT)
        if status_response.status_code != 200:
            print(f"Status response: {status_response.status_code} - {status_response.text}")
            # Continue polling even if the status check fails temporarily.
            continue
        try:
            status_data = status_response.json()
        except requests.exceptions.JSONDecodeError:
            print("Failed to parse status response, continuing...")
            continue
        print(f"Status check {attempt + 1}: {status_data}")

        status = status_data.get("status")
        if status == "COMPLETED":
            result_url = RESULT_URL_TEMPLATE.format(request_id=request_id)
            result_response = requests.get(result_url, headers=headers, timeout=REQUEST_TIMEOUT)
            if result_response.status_code != 200:
                print(f"Result response: {result_response.status_code} - {result_response.text}")
                raise gr.Error(f"Failed to get result: {result_response.status_code}")
            return _image_bytes_from_response(result_response)
        if status == "FAILED":
            raise gr.Error(f"Job failed: {status_data.get('error', 'Unknown error')}")
        # Still processing — loop around and poll again.
    raise gr.Error("Job timed out after 5 minutes")


def query_api(payload, progress_callback=None):
    """Send a generation request to the HF router and return raw image bytes.

    Args:
        payload: Request body. If it contains ``image_bytes`` (raw bytes of
            an input image), they are base64-encoded into ``inputs``.
        progress_callback: Optional ``fn(fraction, message)`` invoked while
            the queued job is being polled.

    Returns:
        bytes: The generated image (PNG/JPEG bytes, or base64 the caller
        may still need to decode).

    Raises:
        gr.Error: On HTTP failure, malformed responses, job failure, or
            timeout.
    """
    headers = get_headers()

    # Inline the uploaded image as base64 under the "inputs" key.
    if "image_bytes" in payload:
        payload["inputs"] = base64.b64encode(payload.pop("image_bytes")).decode("utf-8")

    # Submit the job.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise gr.Error(f"API request failed with status {response.status_code}: {response.text}")

    # Debug the response.
    print(f"Response status: {response.status_code}")
    print(f"Response headers: {dict(response.headers)}")
    print(f"Response content type: {response.headers.get('content-type', 'unknown')}")
    print(f"Response content length: {len(response.content)}")
    print(f"First 500 chars of response: {response.content[:500]}")

    # JSON means a queue-status envelope; anything else is treated as image bytes.
    content_type = response.headers.get('content-type', '').lower()
    if 'application/json' not in content_type:
        return response.content

    try:
        json_response = response.json()
    except requests.exceptions.JSONDecodeError as e:
        raise gr.Error(f"Failed to parse JSON response: {str(e)}")
    print(f"JSON response: {json_response}")

    status = json_response.get("status")
    if status == "IN_QUEUE":
        request_id = json_response.get("request_id")
        if not request_id:
            raise gr.Error("No request_id provided in queue response")
        return _poll_until_complete(request_id, headers, progress_callback)

    if status == "COMPLETED":
        # Job completed immediately — extract the image from the envelope.
        if json_response.get('images'):
            image_info = json_response['images'][0]
            if isinstance(image_info, dict) and 'url' in image_info:
                img_response = requests.get(image_info['url'], timeout=REQUEST_TIMEOUT)
                return img_response.content
            if isinstance(image_info, str):
                return base64.b64decode(image_info)
        if 'image' in json_response:
            return base64.b64decode(json_response['image'])
        raise gr.Error(f"No images found in immediate response: {json_response}")

    raise gr.Error(f"Unexpected response status: {json_response.get('status', 'unknown')}")


def upload_image_to_fal(image_bytes):
    """Encode image bytes as a base64 data URI usable as a fal.ai image_url.

    fal.ai accepts base64 data URIs for ``image_url``, so no actual upload
    is performed. The MIME subtype is sniffed with PIL and defaults to
    ``jpeg`` when the format cannot be determined.
    """
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')
    try:
        img = Image.open(io.BytesIO(image_bytes))
        format_map = {'JPEG': 'jpeg', 'PNG': 'png', 'WEBP': 'webp'}
        img_format = format_map.get(img.format, 'jpeg')
    except Exception:
        # Unreadable header — fall back to jpeg, the server is lenient.
        img_format = 'jpeg'
    return f"data:image/{img_format};base64,{image_base64}"


# NOTE(review): an orphaned, unreachable duplicate of the query logic (it sat
# after upload_image_to_fal's return and called an undefined get_fal_headers())
# was removed here as dead code.


# --- Core Inference Function for ChatInterface ---
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """Perform image generation or editing from the chat interface.

    Args:
        message: Multimodal message dict with ``text`` and ``files`` keys.
        chat_history: Unused; required by the ChatInterface signature.
        seed: RNG seed (replaced when ``randomize_seed`` is set).
        randomize_seed: Whether to draw a fresh random seed.
        guidance_scale: Diffusion guidance scale.
        steps: Number of inference steps.
        progress: Gradio progress tracker (also used as the poll callback).

    Returns:
        gr.Image: The generated/edited image for the chat window.

    Raises:
        gr.Error: On missing input, unreadable upload, or API failure.
    """
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Prepare the payload for the Hugging Face router.
    payload = {
        "parameters": {
            "prompt": prompt,
            "seed": seed,
            "guidance_scale": guidance_scale,
            "num_inference_steps": steps
        }
    }

    if files:
        print(f"Received image: {files[0]}")
        try:
            input_image = Image.open(files[0])
            # Auto-orient from EXIF *before* convert(): converting first can
            # drop the orientation tag and leave the image rotated.
            input_image = ImageOps.exif_transpose(input_image)
            # Convert to RGB if needed (handles RGBA, P, etc.).
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")
            # Serialize to PNG bytes; query_api base64-encodes them.
            img_byte_arr = io.BytesIO()
            input_image.save(img_byte_arr, format='PNG')
            payload["image_bytes"] = img_byte_arr.getvalue()
        except Exception as e:
            raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).")
        progress(0.1, desc="Processing image...")
    else:
        print(f"Received prompt for text-to-image: {prompt}")
        # Text-to-image needs no input image.
        progress(0.1, desc="Generating image...")

    try:
        image_bytes = query_api(payload, progress_callback=progress)

        # Try to interpret the response bytes as an image directly.
        try:
            image = Image.open(io.BytesIO(image_bytes))
        except Exception as img_error:
            print(f"Failed to open image: {img_error}")
            print(f"Image bytes type: {type(image_bytes)}, length: {len(image_bytes) if hasattr(image_bytes, '__len__') else 'unknown'}")
            # Fallback: the payload may be base64-encoded image bytes.
            try:
                image = Image.open(io.BytesIO(base64.b64decode(image_bytes)))
            except Exception:
                raise gr.Error(f"Could not process API response as image. Response length: {len(image_bytes) if hasattr(image_bytes, '__len__') else 'unknown'}")

        progress(1.0, desc="Complete!")
        return gr.Image(value=image)
    except gr.Error:
        # Re-raise gradio errors as-is.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to generate image: {str(e)}")


# --- UI Definition using gr.ChatInterface ---
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev] - Hugging Face Router",
    description="""
A simple chat UI for the FLUX.1 Kontext [dev] model using Hugging Face router.
To edit an image, upload it and type your instructions (e.g., "Add a hat", "Turn the cat into a tiger").
To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
Find the model on Hugging Face.
""",
    multimodal=True,
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt and/or upload an image...",
        render=False
    ),
    additional_inputs=[
        seed_slider,
        randomize_checkbox,
        guidance_slider,
        steps_slider
    ],
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()