File size: 5,429 Bytes
afb8be0 8daf03a afb8be0 8daf03a afb8be0 8daf03a afb8be0 8daf03a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gradio as gr
import requests
import numpy as np
from PIL import Image
import io
import base64
import logging
import sys
import traceback
import os
# Logging: send every record at DEBUG level and above to stdout, prefixed
# with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    format=_LOG_FORMAT,
    level=logging.DEBUG,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("faceforge_ui")
# API configuration. The base URL comes from the API_URL environment variable
# (default: local server on port 8000). Any trailing "/" is stripped because
# request URLs are built as f"{API_URL}/generate"; without stripping, a value
# such as "http://host:8000/" would yield "http://host:8000//generate", which
# the server may not route to the /generate endpoint.
API_URL = os.environ.get("API_URL", "http://localhost:8000").rstrip("/")
logger.info(f"Using API URL: {API_URL}")
def _decode_image(img_b64, raw_size=(256, 256)):
    """Decode the base64 image payload returned by the API into a PIL image.

    A payload of exactly ``width * height * 3`` bytes is read as a raw RGB
    buffer of ``raw_size`` (the format this UI originally assumed). Any other
    payload is opened as an encoded image file (e.g. PNG or JPEG).

    Args:
        img_b64: Base64-encoded image data (str or bytes).
        raw_size: (width, height) used when the payload is a raw RGB buffer.

    Returns:
        PIL.Image or None: The decoded image, or None if decoding failed.
    """
    try:
        img_bytes = base64.b64decode(img_b64)
        width, height = raw_size
        if len(img_bytes) == width * height * 3:
            img = Image.frombytes("RGB", raw_size, img_bytes)
        else:
            img = Image.open(io.BytesIO(img_bytes))
            # Image.open is lazy; force the decode so corrupt data fails here.
            img.load()
        logger.debug("Successfully decoded image")
        return img
    except Exception as e:
        logger.error(f"Error decoding image: {e}")
        logger.debug(traceback.format_exc())
        return None


def generate_image(prompts, mode, player_x, player_y):
    """
    Generate an image based on prompts and player position.

    POSTs the parsed prompts, sampling mode and player position as JSON to
    ``{API_URL}/generate`` (30 s timeout) and decodes the base64 ``"image"``
    field of the JSON response.

    Args:
        prompts: Comma-separated list of prompts
        mode: Sampling mode ('distance' or 'circle')
        player_x: X-coordinate of player position
        player_y: Y-coordinate of player position

    Returns:
        PIL.Image or None: Generated image, or None if generation failed
        (no usable prompts, request/HTTP error, missing or undecodable image).
        Errors are logged rather than raised.
    """
    try:
        logger.debug(f"Generating image with prompts: {prompts}, mode: {mode}, position: ({player_x}, {player_y})")

        # Split on commas and drop empty / whitespace-only entries.
        prompt_list = [p.strip() for p in prompts.split(",") if p.strip()]
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None
        logger.debug(f"Parsed prompts: {prompt_list}")

        req = {
            "prompts": prompt_list,
            "mode": mode,
            "player_pos": [float(player_x), float(player_y)],
        }
        logger.debug(f"Sending request to API: {req}")

        try:
            resp = requests.post(f"{API_URL}/generate", json=req, timeout=30)
            logger.debug(f"API response status: {resp.status_code}")
            if not resp.ok:
                logger.error(f"API error: {resp.status_code} - {resp.text}")
                return None
            data = resp.json()
        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            logger.debug(traceback.format_exc())
            return None
        logger.debug("Successfully received API response")

        # Guard against a JSON body that is not an object (e.g. a list).
        if not isinstance(data, dict) or "image" not in data:
            logger.warning("No image in API response")
            return None
        return _decode_image(data["image"])
    except Exception as e:
        # UI boundary: never let an error propagate into the Gradio handler.
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return None
def on_generate_click(prompts, mode, player_x, player_y):
    """Click handler for the Generate button.

    Calls ``generate_image`` and returns ``[image, status_text]``, matching
    the ``outputs=[img, status]`` wiring of the button. The image slot is
    None whenever no image was produced.
    """
    try:
        logger.info("Generate button clicked")
        generated = generate_image(prompts, mode, player_x, player_y)
    except Exception as e:
        logger.error(f"Error in generate button handler: {e}")
        logger.debug(traceback.format_exc())
        return [None, f"Error: {str(e)}"]

    if generated is None:
        return [None, "Failed to generate image. Check logs for details."]
    return [generated, "Image generated successfully"]


def _ready_status():
    """Initial text shown in the status box when the page loads."""
    return "Ready to generate images"


# Build the Gradio interface: inputs on the left, output image and status on the right.
logger.info("Initializing Gradio interface")
with gr.Blocks() as demo:
    gr.Markdown("# FaceForge Latent Space Explorer")
    with gr.Row():
        with gr.Column():
            prompts = gr.Textbox(
                label="Prompts (comma-separated)",
                value="A photo of a cat, A photo of a dog",
                info="Enter prompts separated by commas",
            )
            mode = gr.Radio(
                choices=["distance", "circle"],
                value="distance",
                label="Sampling Mode",
                info="Choose how to sample the latent space",
            )
            player_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player X")
            player_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player Y")
            btn = gr.Button("Generate")
        with gr.Column():
            img = gr.Image(label="Generated Image")
            status = gr.Textbox(label="Status", interactive=False)

    btn.click(
        fn=on_generate_click,
        inputs=[prompts, mode, player_x, player_y],
        outputs=[img, status],
    )
    demo.load(_ready_status, outputs=status)
if __name__ == "__main__":
    logger.info("Starting Gradio app")
    # Launch settings are identical everywhere; SPACE_ID is only consulted to
    # report which environment the app is running in.
    if "SPACE_ID" in os.environ:
        logger.info("Running in Hugging Face Space")
    else:
        logger.info("Running locally")
    try:
        demo.launch(server_name="0.0.0.0", share=False)
    except Exception as e:
        logger.critical(f"Failed to launch Gradio app: {e}")
        logger.debug(traceback.format_exc())
        # Exit non-zero so whatever started the process sees the failure
        # instead of a clean exit.
        sys.exit(1)