File size: 5,429 Bytes
afb8be0
 
8daf03a
 
 
 
 
 
 
 
afb8be0
8daf03a
 
 
 
 
afb8be0
8daf03a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb8be0
 
8daf03a
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import requests
import numpy as np
from PIL import Image
import io
import base64
import logging
import sys
import traceback
import os

# Configure logging: DEBUG and above, written to stdout.
# NOTE(review): at DEBUG level the full prompts and request payloads are logged;
# consider a lower verbosity for production deployments.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("faceforge_ui")

# API configuration: base URL of the generation backend, overridable via the
# API_URL environment variable. A trailing slash is stripped so that building
# endpoint URLs as f"{API_URL}/generate" never yields a double slash.
API_URL = os.environ.get("API_URL", "http://localhost:8000").rstrip("/")
logger.info(f"Using API URL: {API_URL}")

def _parse_prompts(prompts):
    """Split a comma-separated prompt string into stripped, non-empty prompts.

    Returns an empty list for ``None`` or an empty/whitespace-only string.
    """
    if not prompts:
        return []
    return [p.strip() for p in prompts.split(",") if p.strip()]


def _decode_image(img_bytes, size=(256, 256)):
    """Build an RGB ``PIL.Image`` from the (already base64-decoded) API bytes.

    Two payload formats are accepted:

    * a raw, uncompressed RGB buffer of exactly ``width * height * 3`` bytes
      (the format this client originally assumed), or
    * any encoded image Pillow can identify (PNG, JPEG, ...), which is
      converted to RGB.

    Raises:
        OSError (including ``PIL.UnidentifiedImageError``) if the bytes are
        neither a matching raw buffer nor a readable encoded image.
    """
    width, height = size
    if len(img_bytes) == width * height * 3:
        return Image.frombytes("RGB", size, img_bytes)
    with Image.open(io.BytesIO(img_bytes)) as encoded:
        # convert() loads the pixel data into a new image, so the result stays
        # valid after the context manager closes the source.
        return encoded.convert("RGB")


def generate_image(prompts, mode, player_x, player_y):
    """
    Generate an image based on prompts and player position.

    Sends ``{"prompts": [...], "mode": ..., "player_pos": [x, y]}`` to
    ``POST {API_URL}/generate`` and decodes the base64 ``"image"`` field of
    the JSON response.

    Args:
        prompts: Comma-separated list of prompts
        mode: Sampling mode ('distance' or 'circle')
        player_x: X-coordinate of player position
        player_y: Y-coordinate of player position

    Returns:
        PIL.Image or None: Generated image, or None if generation failed for
        any reason (no prompts, network error, non-2xx status, malformed
        response, undecodable image). Failures are logged, never raised.
    """
    try:
        logger.debug(f"Generating image with prompts: {prompts}, mode: {mode}, position: ({player_x}, {player_y})")

        prompt_list = _parse_prompts(prompts)
        if not prompt_list:
            logger.warning("No valid prompts provided")
            return None
        logger.debug(f"Parsed prompts: {prompt_list}")

        req = {
            "prompts": prompt_list,
            "mode": mode,
            "player_pos": [float(player_x), float(player_y)],
        }
        logger.debug(f"Sending request to API: {req}")

        try:
            resp = requests.post(f"{API_URL}/generate", json=req, timeout=30)
        except requests.exceptions.RequestException as e:
            logger.error(f"Request failed: {e}")
            logger.debug(traceback.format_exc())
            return None

        logger.debug(f"API response status: {resp.status_code}")
        if not resp.ok:
            logger.error(f"API error: {resp.status_code} - {resp.text}")
            return None

        try:
            data = resp.json()
        except ValueError as e:
            # json.JSONDecodeError (and requests' JSONDecodeError) are ValueErrors.
            logger.error(f"Invalid JSON in API response: {e}")
            logger.debug(traceback.format_exc())
            return None
        logger.debug("Successfully received API response")

        # Guard against non-object JSON (list, null, ...) and a missing/empty field.
        img_b64 = data.get("image") if isinstance(data, dict) else None
        if not img_b64:
            logger.warning("No image in API response")
            return None

        try:
            # b64decode raises binascii.Error (a ValueError) on malformed input;
            # _decode_image raises OSError on unreadable image bytes.
            img = _decode_image(base64.b64decode(img_b64))
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            logger.debug(traceback.format_exc())
            return None

        logger.debug("Successfully decoded image")
        return img

    except Exception as e:
        # Last-resort boundary: callers (the UI handler) rely on None, not raises.
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return None

# Build the Gradio UI: generation controls in the left column, the generated
# image and a status line in the right column.
logger.info("Initializing Gradio interface")
with gr.Blocks() as demo:
    gr.Markdown("# FaceForge Latent Space Explorer")

    with gr.Row():
        # Left column: inputs passed to generate_image.
        with gr.Column():
            prompts = gr.Textbox(
                label="Prompts (comma-separated)",
                value="A photo of a cat, A photo of a dog",
                info="Enter prompts separated by commas",
            )
            mode = gr.Radio(
                choices=["distance", "circle"],
                value="distance",
                label="Sampling Mode",
                info="Choose how to sample the latent space",
            )
            player_x = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player X")
            player_y = gr.Slider(minimum=-1.0, maximum=1.0, value=0.0, label="Player Y")
            btn = gr.Button("Generate")

        # Right column: outputs.
        with gr.Column():
            img = gr.Image(label="Generated Image")
            status = gr.Textbox(label="Status", interactive=False)

    def on_generate_click(prompts, mode, player_x, player_y):
        """Button handler; returns ``[image_or_None, status_message]`` for (img, status)."""
        try:
            logger.info("Generate button clicked")
            generated = generate_image(prompts, mode, player_x, player_y)
        except Exception as exc:
            logger.error(f"Error in generate button handler: {exc}")
            logger.debug(traceback.format_exc())
            return [None, f"Error: {str(exc)}"]

        if generated is None:
            return [None, "Failed to generate image. Check logs for details."]
        return [generated, "Image generated successfully"]

    btn.click(
        fn=on_generate_click,
        inputs=[prompts, mode, player_x, player_y],
        outputs=[img, status],
    )

    # Initial status text shown when the page loads.
    demo.load(lambda: "Ready to generate images", outputs=status)

if __name__ == "__main__":
    logger.info("Starting Gradio app")
    # SPACE_ID in the environment indicates a Hugging Face Space. The launch
    # settings are identical in both environments; only the log line differs.
    if "SPACE_ID" in os.environ:
        logger.info("Running in Hugging Face Space")
    else:
        logger.info("Running locally")
    try:
        demo.launch(server_name="0.0.0.0", share=False)
    except Exception as e:
        logger.critical(f"Failed to launch Gradio app: {e}")
        logger.debug(traceback.format_exc())
        # Exit non-zero so supervisors/containers see the failure instead of a
        # clean exit with nothing served.
        sys.exit(1)