# HiDream-I1-Dev / app.py
# Hugging Face Space file-view header captured together with the source:
#   cai-qi's picture | Update app.py | a53dbd0 verified | raw | history blame | 21.8 kB
import logging
import os
import random
import time
import traceback
from io import BytesIO
import gradio as gr
import requests
from PIL import Image, PngImagePlugin
from dotenv import load_dotenv
# Configure process-wide logging and create this module's logger
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment variables (python-dotenv; must run before the os.environ reads below)
load_dotenv()
# API configuration, read from the process environment.
# The string settings are None when the corresponding variable is unset.
API_TOKEN = os.environ.get("HIDREAM_API_TOKEN")
API_REQUEST_URL = os.environ.get("API_REQUEST_URL")
API_RESULT_URL = os.environ.get("API_RESULT_URL")
API_IMAGE_URL = os.environ.get("API_IMAGE_URL")
API_VERSION = os.environ.get("API_VERSION")
API_MODEL_NAME = os.environ.get("API_MODEL_NAME")

# Retry / polling settings. int(None) and float(None) raise TypeError, so an
# unset or empty variable falls back to a default instead of crashing the
# module at import time. A non-numeric value still raises ValueError so that
# a misconfiguration is not silently ignored.
# NOTE(review): the default values below were chosen here - confirm they suit the API.
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT") or 3)    # attempts per request/download
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL") or 1.0)    # seconds slept between status polls
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME") or 300)      # seconds before polling gives up

# Resolution options
ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
class APIError(Exception):
    """Error raised when talking to the image generation API fails."""
def create_request(prompt, aspect_ratio="1:1", seed=-1):
    """
    Create an image generation request to the API.

    Args:
        prompt (str): Text prompt describing the image to generate
        aspect_ratio (str): Aspect ratio of the output image
        seed (int): Seed for reproducibility, -1 for random

    Returns:
        tuple: (task_id, seed) - Task ID if successful and the seed used

    Raises:
        ValueError: If the prompt is empty, the aspect ratio is not one of
            ASPECT_RATIO_OPTIONS, or the seed is not an integer in range
        APIError: If the API request fails
    """
    if not prompt or not prompt.strip():
        raise ValueError("Prompt cannot be empty")

    # Validate aspect ratio
    if aspect_ratio not in ASPECT_RATIO_OPTIONS:
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(ASPECT_RATIO_OPTIONS)}")

    # Validate seed. Convert first so the sentinel and range checks always
    # operate on an int, and keep the range check outside the try block so its
    # message is not replaced by the generic "must be an integer" one.
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("Seed must be an integer") from None
    if seed < -1 or seed > 2147483647:
        raise ValueError("Seed must be -1 or between 0 and 2147483647")

    # Generate random seed if not provided
    if seed == -1:
        seed = random.randint(1, 2147483647)

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
        "Content-Type": "application/json",
    }
    generate_data = {
        "module": "txt2img",
        "prompt": prompt,
        "params": {
            "batch_size": 1,
            "wh_ratio": aspect_ratio,
            "seed": seed
        },
        "version": API_VERSION,
    }

    # Every retryable failure consumes exactly one attempt.
    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Sending API request for prompt: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")
            response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
            response.raise_for_status()

            try:
                result = response.json()
            except ValueError as e:
                # Body was not JSON; report it as a format problem rather
                # than as a connection failure.
                raise APIError("Invalid response format from API") from e

            payload = result.get("result") if isinstance(result, dict) else None
            if not isinstance(payload, dict):
                raise APIError("Invalid response format from API")

            task_id = payload.get("task_id")
            if not task_id:
                raise APIError("No task ID returned from API")

            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id, seed
        except requests.exceptions.Timeout:
            logger.warning(f"Request timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"
            if status_code == 401:
                raise APIError("Authentication failed. Please check your API token.")
            elif status_code == 429:
                wait_time = min(2 ** attempt, 10)  # Exponential backoff, capped at 10s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Other client errors are not retried; include the API's own
                # message when the error body is a JSON object.
                try:
                    error_detail = e.response.json()
                    error_message += f": {error_detail.get('message', 'Client error')}"
                except (ValueError, AttributeError):
                    pass
                raise APIError(error_message)
            else:
                logger.warning(f"Server error: {error_message}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
                time.sleep(1)
        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            raise APIError(f"Failed to connect to API: {str(e)}")
        except APIError:
            # Already a meaningful API error; do not re-wrap it as "Unexpected error".
            raise
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}")

    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
def get_results(task_id):
    """
    Check the status of an image generation task.

    This is a best-effort poll: failures that a later poll might not hit are
    logged and reported as None so the caller can simply try again.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict or None: Parsed API response containing a "result" key, or None
            when the check failed (timeout, network error, HTTP error other
            than 401, malformed response, or any unexpected error)

    Raises:
        ValueError: If task_id is empty
        APIError: If the API rejects the token (HTTP 401)
    """
    if not task_id:
        raise ValueError("Task ID cannot be empty")

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
    }
    try:
        # Let requests build the query string so the task id is URL-encoded.
        response = requests.get(API_RESULT_URL, params={"task_id": task_id}, headers=headers, timeout=10)
        response.raise_for_status()

        try:
            result = response.json()
        except ValueError:
            result = None
        if not isinstance(result, dict) or "result" not in result:
            logger.error(f"Invalid response format from API when checking task {task_id}")
            return None
        return result
    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        if status_code == 401:
            raise APIError("Authentication failed. Please check your API token.")
        elif 400 <= status_code < 500:
            try:
                error_detail = e.response.json()
                error_message = f"HTTP error {status_code}: {error_detail.get('message', 'Client error')}"
            except (ValueError, AttributeError):
                # Error body was not a JSON object; fall back to the bare status.
                error_message = f"HTTP error {status_code}"
            logger.error(error_message)
            return None
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
            return None
    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        return None
def download_image(image_url):
    """
    Fetch an image over HTTP and hand it back as a PNG-format PIL Image.

    Images that arrive in another format (WebP, JPEG, ...) are re-encoded to
    PNG in memory; string-valued entries of the original ``info`` dict are
    copied onto the re-encoded image.

    Args:
        image_url (str): URL of the image

    Returns:
        PIL.Image: The downloaded image, in PNG format

    Raises:
        ValueError: If image_url is empty
        APIError: If the download fails (client error, processing error, or
            all MAX_RETRY_COUNT attempts used up)
    """
    if not image_url:
        raise ValueError("Image URL cannot be empty")

    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Downloading image from {image_url}")
            response = requests.get(image_url, timeout=15)
            response.raise_for_status()

            picture = Image.open(BytesIO(response.content))

            # Remember the textual metadata before any re-encoding
            text_info = {
                name: content
                for name, content in picture.info.items()
                if isinstance(name, str) and isinstance(content, str)
            }

            if picture.format != 'PNG':
                logger.info(f"Converting image from {picture.format} to PNG format")
                # Keep an alpha channel when there is one, otherwise normalise to RGB
                has_alpha = 'A' in picture.getbands()
                source = picture if has_alpha else picture.convert('RGB')
                encoded = BytesIO()
                source.save(encoded, format='PNG')
                encoded.seek(0)
                picture = Image.open(encoded)
                # Carry the original textual metadata over to the PNG
                picture.info.update(text_info)

            logger.info(f"Successfully downloaded and processed image: {picture.size[0]}x{picture.size[1]}")
            return picture
        except requests.exceptions.Timeout:
            logger.warning(f"Download timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            if 400 <= status_code < 500:
                # Client errors are not retried
                error_message = f"HTTP error {status_code} when downloading image"
                logger.error(error_message)
                raise APIError(error_message)
            logger.warning(f"Server error {status_code}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except requests.exceptions.RequestException as e:
            logger.warning(f"Network error: {str(e)}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)
        except Exception as e:
            logger.error(f"Error processing image: {str(e)}\n{traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}")

    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
def add_metadata_to_image(image, metadata):
    """
    Return a PNG copy of a PIL image carrying the given text metadata.

    String-valued entries already present in ``image.info`` are kept; entries
    from ``metadata`` win when a key appears in both.

    Args:
        image (PIL.Image): The image to add metadata to
        metadata (dict): Metadata to add to the image

    Returns:
        PIL.Image: Image with metadata, None if no image was given, or the
            original image if adding the metadata failed
    """
    if not image:
        return None

    try:
        # Textual metadata already attached to the image
        current_text = {
            name: content
            for name, content in image.info.items()
            if isinstance(name, str) and isinstance(content, str)
        }
        combined = {**current_text, **metadata}

        # Build the PNG text chunks
        png_info = PngImagePlugin.PngInfo()
        for name, content in combined.items():
            png_info.add_text(name, str(content))

        # Round-trip through an in-memory PNG so the metadata is embedded
        out = BytesIO()
        image.save(out, format='PNG', pnginfo=png_info)
        out.seek(0)
        return Image.open(out)
    except Exception as e:
        logger.error(f"Failed to add metadata to image: {str(e)}\n{traceback.format_exc()}")
        return image  # Fall back to the untouched image
# Create Gradio interface
def create_ui():
    """
    Build the Gradio Blocks interface for the image generator.

    The layout has a header, an input column (prompt, aspect ratio, seed,
    buttons, status) and an output column (image plus a JSON details panel).
    The generate button and the examples both run ``generate_with_status``.

    NOTE(review): the nesting of the layout containers was reconstructed from
    source whose indentation had been lost - confirm against the deployed UI.

    Returns:
        gr.Blocks: The assembled (not yet launched) application
    """
    with gr.Blocks(title="HiDream-I1-Dev Image Generator", theme=gr.themes.Soft()) as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                gr.Markdown("""
                # HiDream-I1-Dev Image Generator
                Generate high-quality images from text descriptions using state-of-the-art AI
                [🤗 HuggingFace](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
                [GitHub](https://github.com/HiDream-ai/HiDream-I1) |
                [Twitter](https://x.com/vivago_ai)
                <span style="color: #FF5733; font-weight: bold">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>
                """)

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A vibrant and dynamic graffiti mural adorns a weathered brick wall in a bustling urban alleyway, a burst of color and energy amidst the city's grit. Boldly spray-painted letters declare \"HiDream.ai\" alongside other intricate street art designs, a testament to creative expression in the urban landscape.",
                    lines=3
                )
                with gr.Row():
                    aspect_ratio = gr.Radio(
                        choices=ASPECT_RATIO_OPTIONS,
                        value=ASPECT_RATIO_OPTIONS[2],
                        label="Aspect Ratio",
                        info="Select image aspect ratio"
                    )
                    seed = gr.Number(
                        label="Seed (use -1 for random)",
                        value=82706,
                        precision=0
                    )
                with gr.Row():
                    generate_btn = gr.Button("Generate Image", variant="primary")
                    clear_btn = gr.Button("Clear")
                seed_used = gr.Number(label="Seed Used", interactive=False)
                status_msg = gr.Markdown("Status: Ready")
                progress = gr.Progress(track_tqdm=False)

            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Image", format="png", type="pil", interactive=False)
                with gr.Accordion("Image Information", open=False):
                    image_info = gr.JSON(label="Details")

        # Status message update function (not wired to any event in this function)
        def update_status(step):
            return f"Status: {step}"

        # Generate function with status updates
        def generate_with_status(prompt, aspect_ratio, seed, progress=gr.Progress()):
            """
            Run one generation and stream UI updates.

            Yields tuples of (image, seed, status text, info dict) matching
            the outputs [output_image, seed_used, status_msg, image_info].
            """
            status_update = "Sending request to API..."
            yield None, seed, status_update, None

            try:
                if not prompt.strip():
                    status_update = "Error: Prompt cannot be empty"
                    yield None, seed, status_update, None
                    return

                # Create request
                task_id, used_seed = create_request(prompt, aspect_ratio, seed)
                status_update = f"Request sent. Task ID: {task_id}. Waiting for results..."
                yield None, used_seed, status_update, None

                # Poll for results
                start_time = time.time()
                last_completion_ratio = 0
                progress(0, desc="Initializing...")

                while time.time() - start_time < MAX_POLL_TIME:
                    result = get_results(task_id)
                    if not result:
                        time.sleep(POLL_INTERVAL)
                        continue

                    # Tolerate null "result" / "sub_task_results" values
                    sub_results = (result.get("result") or {}).get("sub_task_results") or []
                    if not sub_results:
                        time.sleep(POLL_INTERVAL)
                        continue

                    sub_task = sub_results[0]
                    status = sub_task.get("task_status")

                    # Get and display completion ratio (a missing/None value counts as 0)
                    completion_ratio = (sub_task.get('task_completion') or 0) * 100
                    ratio_changed = completion_ratio != last_completion_ratio
                    if ratio_changed:
                        # Only update UI when completion ratio changes
                        last_completion_ratio = completion_ratio
                        # Rounded so float noise such as 7.000000000000001 is not shown
                        status_update = f"Generating image: {completion_ratio:.0f}% complete"
                        progress(completion_ratio / 100, desc="Generating image")
                        yield None, used_seed, status_update, None

                    # Check task status
                    if status == 1:  # Success
                        progress(1.0, desc="Generation complete")
                        image_name = sub_task.get("image")
                        if not image_name:
                            status_update = "Error: No image name in successful response"
                            yield None, used_seed, status_update, None
                            return

                        status_update = "Downloading generated image..."
                        yield None, used_seed, status_update, None

                        image_url = f"{API_IMAGE_URL}{image_name}.png"
                        image = download_image(image_url)
                        if image:
                            # Add metadata to the image
                            metadata = {
                                "prompt": prompt,
                                "seed": str(used_seed),
                                "model": API_MODEL_NAME,
                                "aspect_ratio": aspect_ratio,
                                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                                "generated_by": "HiDream-I1-Dev Generator"
                            }
                            image_with_metadata = add_metadata_to_image(image, metadata)

                            # Create info for display
                            info = {
                                "model": API_MODEL_NAME,
                                "prompt": prompt,
                                "seed": used_seed,
                                "aspect_ratio": aspect_ratio,
                                "generated_at": time.strftime("%Y-%m-%d %H:%M:%S")
                            }
                            status_update = "Image generated successfully!"
                            yield image_with_metadata, used_seed, status_update, info
                            return
                        else:
                            status_update = "Error: Failed to download the generated image"
                            yield None, used_seed, status_update, None
                            return
                    elif status in {3, 4}:  # Failed or Canceled
                        error_msg = sub_task.get("task_error", "Unknown error")
                        status_update = f"Error: Task failed with status {status}: {error_msg}"
                        yield None, used_seed, status_update, None
                        return

                    # Only update time elapsed if completion ratio didn't change.
                    # The flag is needed because last_completion_ratio has
                    # already been synced above, so comparing the two values
                    # here would always be true.
                    if not ratio_changed:
                        status_update = f"Waiting for image generation... {completion_ratio:.0f}% complete ({int(time.time() - start_time)}s elapsed)"
                        yield None, used_seed, status_update, None

                    time.sleep(POLL_INTERVAL)

                status_update = f"Error: Timeout waiting for image generation after {MAX_POLL_TIME} seconds"
                yield None, used_seed, status_update, None
            except APIError as e:
                status_update = f"API Error: {str(e)}"
                yield None, seed, status_update, None
            except ValueError as e:
                status_update = f"Value Error: {str(e)}"
                yield None, seed, status_update, None
            except Exception as e:
                status_update = f"Unexpected error: {str(e)}"
                yield None, seed, status_update, None

        # Set up event handlers
        generate_btn.click(
            fn=generate_with_status,
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        def clear_outputs():
            return None, -1, "Status: Ready", None

        clear_btn.click(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        # Examples
        gr.Examples(
            examples=[
                [
                    "A vibrant and dynamic graffiti mural adorns a weathered brick wall in a bustling urban alleyway, a burst of color and energy amidst the city's grit. Boldly spray-painted letters declare \"HiDream.ai\" alongside other intricate street art designs, a testament to creative expression in the urban landscape.",
                    "4:3", 82706],
                [
                    "A modern art interpretation of a traditional landscape painting, using bold colors and abstract forms to represent mountains, rivers, and mist. Incorporate calligraphic elements and a sense of dynamic energy.",
                    "1:1", 661320],
                [
                    "Intimate portrait of a young woman from a nomadic tribe in ancient China, wearing fur-trimmed clothing and intricate silver jewelry. Wind-swept hair and a resilient gaze. Background of a vast, open grassland under a dramatic sky.",
                    "1:1", 34235],
                [
                    "Time-lapse concept: A single tree shown through four seasons simultaneously, spring blossoms, summer green, autumn colors, winter snow, blended seamlessly.",
                    "1:1", 241106]
            ],
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info],
            fn=generate_with_status,
            cache_examples=False
        )

    return demo
# Entry point: build the interface, enable the request queue, then serve it
if __name__ == "__main__":
    demo = create_ui()
    queued_app = demo.queue(max_size=10, default_concurrency_limit=5)
    queued_app.launch()