cai-qi committed on
Commit
a53dbd0
·
verified ·
1 Parent(s): 2d79e39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +542 -167
app.py CHANGED
@@ -1,181 +1,556 @@
1
- import torch
 
 
 
 
 
 
2
  import gradio as gr
3
- from hi_diffusers import HiDreamImagePipeline
4
- from hi_diffusers import HiDreamImageTransformer2DModel
5
- from hi_diffusers.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
6
- from hi_diffusers.schedulers.flash_flow_match import FlashFlowMatchEulerDiscreteScheduler
7
- from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
8
-
9
- MODEL_PREFIX = "HiDream-ai"
10
- LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
11
-
12
- # Model configurations
13
- MODEL_CONFIGS = {
14
- "dev": {
15
- "path": f"{MODEL_PREFIX}/HiDream-I1-Dev",
16
- "guidance_scale": 0.0,
17
- "num_inference_steps": 28,
18
- "shift": 6.0,
19
- "scheduler": FlashFlowMatchEulerDiscreteScheduler
20
- },
21
- "full": {
22
- "path": f"{MODEL_PREFIX}/HiDream-I1-Full",
23
- "guidance_scale": 5.0,
24
- "num_inference_steps": 50,
25
- "shift": 3.0,
26
- "scheduler": FlowUniPCMultistepScheduler
27
- },
28
- "fast": {
29
- "path": f"{MODEL_PREFIX}/HiDream-I1-Fast",
30
- "guidance_scale": 0.0,
31
- "num_inference_steps": 16,
32
- "shift": 3.0,
33
- "scheduler": FlashFlowMatchEulerDiscreteScheduler
34
- }
35
- }
36
 
37
  # Resolution options
38
- RESOLUTION_OPTIONS = [
39
- "1024 × 1024 (Square)",
40
- "768 × 1360 (Portrait)",
41
- "1360 × 768 (Landscape)",
42
- "880 × 1168 (Portrait)",
43
- "1168 × 880 (Landscape)",
44
- "1248 × 832 (Landscape)",
45
- "832 × 1248 (Portrait)"
46
- ]
47
-
48
- # Load models
49
- def load_models(model_type):
50
- config = MODEL_CONFIGS[model_type]
51
- pretrained_model_name_or_path = config["path"]
52
- scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=config["shift"], use_dynamic_shifting=False)
53
-
54
- tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
55
- LLAMA_MODEL_NAME,
56
- use_fast=False)
57
-
58
- text_encoder_4 = LlamaForCausalLM.from_pretrained(
59
- LLAMA_MODEL_NAME,
60
- output_hidden_states=True,
61
- output_attentions=True,
62
- torch_dtype=torch.bfloat16).to("cuda")
63
-
64
- transformer = HiDreamImageTransformer2DModel.from_pretrained(
65
- pretrained_model_name_or_path,
66
- subfolder="transformer",
67
- torch_dtype=torch.bfloat16).to("cuda")
68
-
69
- pipe = HiDreamImagePipeline.from_pretrained(
70
- pretrained_model_name_or_path,
71
- scheduler=scheduler,
72
- tokenizer_4=tokenizer_4,
73
- text_encoder_4=text_encoder_4,
74
- torch_dtype=torch.bfloat16
75
- ).to("cuda", torch.bfloat16)
76
- pipe.transformer = transformer
77
-
78
- return pipe, config
79
-
80
- # Parse resolution string to get height and width
81
- def parse_resolution(resolution_str):
82
- if "1024 × 1024" in resolution_str:
83
- return 1024, 1024
84
- elif "768 × 1360" in resolution_str:
85
- return 768, 1360
86
- elif "1360 × 768" in resolution_str:
87
- return 1360, 768
88
- elif "880 × 1168" in resolution_str:
89
- return 880, 1168
90
- elif "1168 × 880" in resolution_str:
91
- return 1168, 880
92
- elif "1248 × 832" in resolution_str:
93
- return 1248, 832
94
- elif "832 × 1248" in resolution_str:
95
- return 832, 1248
96
- else:
97
- return 1024, 1024 # Default fallback
98
-
99
- # Generate image function
100
- def generate_image(model_type, prompt, resolution, seed):
101
- global pipe, current_model
102
-
103
- # Get configuration for current model
104
- config = MODEL_CONFIGS[model_type]
105
- guidance_scale = config["guidance_scale"]
106
- num_inference_steps = config["num_inference_steps"]
107
-
108
- # Parse resolution
109
- height, width = parse_resolution(resolution)
110
 
111
- # Handle seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  if seed == -1:
113
- seed = torch.randint(0, 1000000, (1,)).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- generator = torch.Generator("cuda").manual_seed(seed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- images = pipe(
118
- prompt,
119
- height=height,
120
- width=width,
121
- guidance_scale=guidance_scale,
122
- num_inference_steps=num_inference_steps,
123
- num_images_per_prompt=1,
124
- generator=generator
125
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
- return images[0], seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- # Initialize with default model
130
- print("Loading default model (full)...")
131
- current_model = "fast"
132
- pipe, _ = load_models(current_model)
133
- print("Model loaded successfully!")
134
 
135
  # Create Gradio interface
136
- with gr.Blocks(title="HiDream Image Generator") as demo:
137
- gr.Markdown("# HiDream Image Generator")
138
-
139
- with gr.Row():
140
- with gr.Column():
141
- model_type = gr.Radio(
142
- choices=list(MODEL_CONFIGS.keys()),
143
- value="full",
144
- label="Model Type",
145
- info="Select model variant"
146
- )
147
-
148
- prompt = gr.Textbox(
149
- label="Prompt",
150
- placeholder="A cat holding a sign that says \"Hi-Dreams.ai\".",
151
- lines=3
152
- )
153
-
154
- resolution = gr.Radio(
155
- choices=RESOLUTION_OPTIONS,
156
- value=RESOLUTION_OPTIONS[0],
157
- label="Resolution",
158
- info="Select image resolution"
159
- )
160
-
161
- seed = gr.Number(
162
- label="Seed (use -1 for random)",
163
- value=-1,
164
- precision=0
165
- )
166
-
167
- generate_btn = gr.Button("Generate Image")
168
- seed_used = gr.Number(label="Seed Used", interactive=False)
169
-
170
- with gr.Column():
171
- output_image = gr.Image(label="Generated Image", type="pil")
172
-
173
- generate_btn.click(
174
- fn=generate_image,
175
- inputs=[model_type, prompt, resolution, seed],
176
- outputs=[output_image, seed_used]
177
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
  # Launch app
180
  if __name__ == "__main__":
181
- demo.launch()
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import time
5
+ import traceback
6
+ from io import BytesIO
7
+
8
  import gradio as gr
9
+ import requests
10
+ from PIL import Image, PngImagePlugin
11
+ from dotenv import load_dotenv
12
+
13
+ # Set up logging
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Load environment variables
21
+ load_dotenv()
22
+
23
# API Configuration
# Credentials and endpoints are read from the environment (optionally populated
# from a .env file). They stay None when unset; the request helpers then fail
# at call time rather than at import time.
API_TOKEN = os.environ.get("HIDREAM_API_TOKEN")
API_REQUEST_URL = os.environ.get("API_REQUEST_URL")
API_RESULT_URL = os.environ.get("API_RESULT_URL")
API_IMAGE_URL = os.environ.get("API_IMAGE_URL")
API_VERSION = os.environ.get("API_VERSION")
API_MODEL_NAME = os.environ.get("API_MODEL_NAME")

# Numeric tunables. int(None) / float(None) raise TypeError, which previously
# crashed the module at import whenever one of these variables was missing, so
# fall back to a default when the variable is unset or empty.
# NOTE(review): the fallback values below are placeholders chosen for this
# change - confirm them against the deployment's intended settings.
MAX_RETRY_COUNT = int(os.environ.get("MAX_RETRY_COUNT") or 3)    # attempts per HTTP call
POLL_INTERVAL = float(os.environ.get("POLL_INTERVAL") or 1.0)    # seconds between status polls
MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME") or 300)      # seconds before polling gives up
 
 
 
 
 
 
 
 
 
33
 
34
  # Resolution options
35
+ ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
36
+
37
+
38
class APIError(Exception):
    """Error raised when a call to the image generation API fails."""
41
+
42
+
43
def create_request(prompt, aspect_ratio="1:1", seed=-1):
    """
    Create an image generation request to the API.

    Args:
        prompt (str): Text prompt describing the image to generate
        aspect_ratio (str): Aspect ratio of the output image; must be one of
            ASPECT_RATIO_OPTIONS
        seed (int): Seed for reproducibility, -1 for random

    Returns:
        tuple: (task_id, seed) - Task ID if successful and the seed used

    Raises:
        ValueError: If the prompt is empty, the aspect ratio is not supported,
            or the seed is not an integer in the accepted range
        APIError: If the API request fails
    """
    if not prompt or not prompt.strip():
        raise ValueError("Prompt cannot be empty")

    # Validate aspect ratio
    if aspect_ratio not in ASPECT_RATIO_OPTIONS:
        raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(ASPECT_RATIO_OPTIONS)}")

    # Coerce the seed first. Only a failed conversion is reported as "not an
    # integer"; the range check lives outside this try so its own message is
    # not swallowed and replaced by the conversion message.
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("Seed must be an integer") from None

    if seed < -1 or seed > 2147483647:
        raise ValueError("Seed must be -1 or between 0 and 2147483647")

    # Generate random seed if not provided (checked after coercion so that
    # "-1" and -1.0 are treated the same as -1)
    if seed == -1:
        seed = random.randint(1, 2147483647)

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
        "Content-Type": "application/json",
    }

    generate_data = {
        "module": "txt2img",
        "prompt": prompt,
        "params": {
            "batch_size": 1,
            "wh_ratio": aspect_ratio,
            "seed": seed
        },
        "version": API_VERSION,
    }

    # Every path through the loop body either returns, raises, or bumps
    # retry_count, so the loop always terminates.
    retry_count = 0
    while retry_count < MAX_RETRY_COUNT:
        try:
            logger.info(f"Sending API request for prompt: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")
            response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
            response.raise_for_status()

            result = response.json()
            if not result or "result" not in result:
                raise APIError("Invalid response format from API")

            task_id = result.get("result", {}).get("task_id")
            if not task_id:
                raise APIError("No task ID returned from API")

            logger.info(f"Successfully created task with ID: {task_id}")
            return task_id, seed

        except requests.exceptions.Timeout:
            retry_count += 1
            logger.warning(f"Request timed out. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            error_message = f"HTTP error {status_code}"

            if status_code == 401:
                raise APIError("Authentication failed. Please check your API token.")
            elif status_code == 429:
                retry_count += 1
                wait_time = min(2 ** retry_count, 10)  # Exponential backoff, capped at 10s
                logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry...")
                time.sleep(wait_time)
            elif 400 <= status_code < 500:
                # Other client errors are not retried; attach the server's
                # message when the error body is a JSON object.
                try:
                    error_detail = e.response.json()
                    error_message += f": {error_detail.get('message', 'Client error')}"
                except (ValueError, AttributeError):
                    # Body was not JSON, or not a JSON object: keep the bare status
                    pass
                raise APIError(error_message)
            else:
                retry_count += 1
                logger.warning(f"Server error: {error_message}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
                time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.error(f"Request error: {str(e)}")
            raise APIError(f"Failed to connect to API: {str(e)}")

        except APIError:
            # Raised deliberately above; pass through unchanged instead of
            # re-wrapping it as an "Unexpected error" in the handler below.
            raise

        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
            raise APIError(f"Unexpected error: {str(e)}") from e

    raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
149
+
150
+
151
def get_results(task_id):
    """
    Check the status of an image generation task.

    Apart from an authentication failure, every problem is logged and reported
    as None so that a polling caller can simply try again.

    Args:
        task_id (str): The task ID to check

    Returns:
        dict or None: The parsed response (a dict containing a "result" key),
        or None if the check timed out, hit a network or HTTP error other than
        401, or returned a payload in an unexpected format

    Raises:
        ValueError: If task_id is empty
        APIError: If authentication fails (HTTP 401)
    """
    if not task_id:
        raise ValueError("Task ID cannot be empty")

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "X-accept-language": "en",
    }

    try:
        # Send the task ID through `params` so requests URL-encodes it instead
        # of it being pasted into the query string verbatim.
        response = requests.get(API_RESULT_URL, params={"task_id": task_id}, headers=headers, timeout=10)
        response.raise_for_status()
        result = response.json()

    except requests.exceptions.Timeout:
        logger.warning(f"Request timed out when checking task {task_id}")
        return None

    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        if status_code == 401:
            raise APIError("Authentication failed. Please check your API token.")
        elif 400 <= status_code < 500:
            try:
                error_detail = e.response.json()
                error_message = f"HTTP error {status_code}: {error_detail.get('message', 'Client error')}"
            except (ValueError, AttributeError):
                # Error body was not JSON, or not a JSON object
                error_message = f"HTTP error {status_code}"
            logger.error(error_message)
            return None
        else:
            logger.warning(f"Server error {status_code} when checking task {task_id}")
            return None

    except requests.exceptions.RequestException as e:
        logger.warning(f"Network error when checking task {task_id}: {str(e)}")
        return None

    except Exception as e:
        logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
        return None

    # A malformed payload is treated like any other transient failure. The
    # previous code raised APIError here and immediately swallowed it in the
    # generic handler, so the caller already received None in this case.
    if not isinstance(result, dict) or "result" not in result:
        logger.error(f"Unexpected error when checking task {task_id}: Invalid response format from API")
        return None

    return result
210
+
211
+
212
def download_image(image_url):
    """
    Fetch an image over HTTP and hand it back as a PNG-format PIL Image.

    Images that arrive in another format (WebP, JPEG, ...) are re-encoded to
    PNG in memory; string-valued entries of the original ``image.info`` are
    carried over to the re-encoded image.

    Args:
        image_url (str): URL of the image

    Returns:
        PIL.Image: Downloaded image object converted to PNG format

    Raises:
        ValueError: If image_url is empty
        APIError: If the download fails (4xx response, processing error, or
            retries exhausted)
    """
    if not image_url:
        raise ValueError("Image URL cannot be empty")

    # Each pass either returns, raises, or uses up one attempt.
    for attempt in range(1, MAX_RETRY_COUNT + 1):
        try:
            logger.info(f"Downloading image from {image_url}")
            response = requests.get(image_url, timeout=15)
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))

            # Snapshot textual metadata before any re-encoding
            original_metadata = {
                key: value
                for key, value in image.info.items()
                if isinstance(key, str) and isinstance(value, str)
            }

            if image.format != 'PNG':
                logger.info(f"Converting image from {image.format} to PNG format")
                png_buffer = BytesIO()

                # Keep the alpha channel when there is one, otherwise normalise to RGB
                has_alpha = 'A' in image.getbands()
                image_to_save = image if has_alpha else image.convert('RGB')

                image_to_save.save(png_buffer, format='PNG')
                png_buffer.seek(0)
                image = Image.open(png_buffer)

            # Re-attach the snapshot (a no-op for an image that was already PNG)
            image.info.update(original_metadata)

            width, height = image.size[0], image.size[1]
            logger.info(f"Successfully downloaded and processed image: {width}x{height}")
            return image

        except requests.exceptions.Timeout:
            logger.warning(f"Download timed out. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.HTTPError as e:
            status_code = e.response.status_code
            if 400 <= status_code < 500:
                # Client errors will not fix themselves: fail immediately
                error_message = f"HTTP error {status_code} when downloading image"
                logger.error(error_message)
                raise APIError(error_message)
            logger.warning(f"Server error {status_code}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except requests.exceptions.RequestException as e:
            logger.warning(f"Network error: {str(e)}. Retrying ({attempt}/{MAX_RETRY_COUNT})...")
            time.sleep(1)

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}\n{traceback.format_exc()}")
            raise APIError(f"Failed to process image: {str(e)}")

    raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
293
+
294
+
295
def add_metadata_to_image(image, metadata):
    """
    Return a PNG copy of ``image`` carrying ``metadata`` as text chunks.

    String-valued entries already present in ``image.info`` are kept; keys in
    ``metadata`` take precedence over them. Values are stored via ``str()``.

    Args:
        image (PIL.Image): The image to add metadata to
        metadata (dict): Metadata to add to the image

    Returns:
        PIL.Image: Image with metadata, the original image if embedding
        failed, or None if no image was given
    """
    if not image:
        return None

    try:
        # Textual metadata the image already has
        carried_over = {
            key: value
            for key, value in image.info.items()
            if isinstance(key, str) and isinstance(value, str)
        }

        # New entries win over existing ones
        combined = {**carried_over, **metadata}

        png_info = PngImagePlugin.PngInfo()
        for name, text in combined.items():
            png_info.add_text(name, str(text))

        # Round-trip through an in-memory PNG so the text chunks are embedded
        out = BytesIO()
        image.save(out, format='PNG', pnginfo=png_info)
        out.seek(0)
        return Image.open(out)

    except Exception as e:
        logger.error(f"Failed to add metadata to image: {str(e)}\n{traceback.format_exc()}")
        return image  # Best effort: fall back to the untouched image
337
 
 
 
 
 
 
338
 
339
  # Create Gradio interface
340
def create_ui():
    """
    Build the Gradio Blocks interface for the image generator.

    The page has a header, a left column with the inputs (prompt, aspect
    ratio, seed), buttons and status text, and a right column with the
    generated image and a collapsible JSON panel. The "Generate Image" button
    and the example rows are bound to a generator handler that creates an API
    task, polls it, and streams status updates to the page.

    NOTE(review): the component nesting below was reconstructed from a
    flattened copy of the source - confirm the Row/Column layout.

    Returns:
        gr.Blocks: The assembled interface (not yet queued or launched)
    """
    with gr.Blocks(title="HiDream-I1-Dev Image Generator", theme=gr.themes.Soft()) as demo:
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                gr.Markdown("""
                    # HiDream-I1-Dev Image Generator

                    Generate high-quality images from text descriptions using state-of-the-art AI

                    [🤗 HuggingFace](https://huggingface.co/HiDream-ai/HiDream-I1-Dev) |
                    [GitHub](https://github.com/HiDream-ai/HiDream-I1) |
                    [Twitter](https://x.com/vivago_ai)

                    <span style="color: #FF5733; font-weight: bold">For more features and to experience the full capabilities of our product, please visit [https://vivago.ai/](https://vivago.ai/).</span>
                """)

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A vibrant and dynamic graffiti mural adorns a weathered brick wall in a bustling urban alleyway, a burst of color and energy amidst the city's grit. Boldly spray-painted letters declare \"HiDream.ai\" alongside other intricate street art designs, a testament to creative expression in the urban landscape.",
                    lines=3
                )

                with gr.Row():
                    aspect_ratio = gr.Radio(
                        choices=ASPECT_RATIO_OPTIONS,
                        value=ASPECT_RATIO_OPTIONS[2],
                        label="Aspect Ratio",
                        info="Select image aspect ratio"
                    )

                    seed = gr.Number(
                        label="Seed (use -1 for random)",
                        value=82706,
                        precision=0
                    )

                with gr.Row():
                    generate_btn = gr.Button("Generate Image", variant="primary")
                    clear_btn = gr.Button("Clear")

                seed_used = gr.Number(label="Seed Used", interactive=False)
                status_msg = gr.Markdown("Status: Ready")

            with gr.Column(scale=1):
                output_image = gr.Image(label="Generated Image", format="png", type="pil", interactive=False)

                with gr.Accordion("Image Information", open=False):
                    image_info = gr.JSON(label="Details")

        # Generate function with status updates.
        # Every yield is (image, seed_used, status text, info dict), matching
        # the `outputs` lists wired up below.
        def generate_with_status(prompt, aspect_ratio, seed, progress=gr.Progress()):
            # Track the seed to display so that an error after the request was
            # created still reports the seed that was actually used.
            used_seed = seed
            yield None, used_seed, "Sending request to API...", None

            try:
                if not prompt or not prompt.strip():
                    yield None, used_seed, "Error: Prompt cannot be empty", None
                    return

                # Create request
                task_id, used_seed = create_request(prompt, aspect_ratio, seed)
                status_update = f"Request sent. Task ID: {task_id}. Waiting for results..."
                yield None, used_seed, status_update, None

                # Poll for results
                start_time = time.time()
                last_completion_ratio = 0
                progress(0, desc="Initializing...")

                while time.time() - start_time < MAX_POLL_TIME:
                    result = get_results(task_id)
                    if not result:
                        time.sleep(POLL_INTERVAL)
                        continue

                    sub_results = (result.get("result") or {}).get("sub_task_results") or []
                    if not sub_results:
                        time.sleep(POLL_INTERVAL)
                        continue

                    task = sub_results[0]
                    status = task.get("task_status")

                    # Completion is multiplied by 100 to show a percentage.
                    # Treat a missing/None value as 0 and round away float
                    # noise such as 56.99999999999999.
                    completion_ratio = round((task.get("task_completion") or 0) * 100, 1)
                    ratio_changed = completion_ratio != last_completion_ratio
                    if ratio_changed:
                        last_completion_ratio = completion_ratio
                        progress(min(max(completion_ratio / 100, 0), 1), desc="Generating image")
                        status_update = f"Generating image: {completion_ratio:g}% complete"
                        yield None, used_seed, status_update, None

                    # Check task status
                    if status == 1:  # Success
                        progress(1.0, desc="Generation complete")
                        image_name = task.get("image")
                        if not image_name:
                            yield None, used_seed, "Error: No image name in successful response", None
                            return

                        yield None, used_seed, "Downloading generated image...", None

                        image_url = f"{API_IMAGE_URL}{image_name}.png"
                        image = download_image(image_url)

                        if image:
                            # Add metadata to the image
                            metadata = {
                                "prompt": prompt,
                                "seed": str(used_seed),
                                "model": API_MODEL_NAME,
                                "aspect_ratio": aspect_ratio,
                                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                                "generated_by": "HiDream-I1-Dev Generator"
                            }

                            image_with_metadata = add_metadata_to_image(image, metadata)

                            # Create info for display
                            info = {
                                "model": API_MODEL_NAME,
                                "prompt": prompt,
                                "seed": used_seed,
                                "aspect_ratio": aspect_ratio,
                                "generated_at": time.strftime("%Y-%m-%d %H:%M:%S")
                            }

                            yield image_with_metadata, used_seed, "Image generated successfully!", info
                            return
                        else:
                            yield None, used_seed, "Error: Failed to download the generated image", None
                            return

                    elif status in {3, 4}:  # Failed or Canceled
                        error_msg = task.get("task_error", "Unknown error")
                        status_update = f"Error: Task failed with status {status}: {error_msg}"
                        yield None, used_seed, status_update, None
                        return

                    # Only show the elapsed-time message when this poll brought
                    # no progress change. The old check compared the ratio with
                    # a value that had just been set equal to it, so it was
                    # always true and overwrote every "Generating image" update.
                    if not ratio_changed:
                        elapsed = int(time.time() - start_time)
                        status_update = f"Waiting for image generation... {completion_ratio:g}% complete ({elapsed}s elapsed)"
                        yield None, used_seed, status_update, None

                    time.sleep(POLL_INTERVAL)

                status_update = f"Error: Timeout waiting for image generation after {MAX_POLL_TIME} seconds"
                yield None, used_seed, status_update, None

            except APIError as e:
                yield None, used_seed, f"API Error: {str(e)}", None

            except ValueError as e:
                yield None, used_seed, f"Value Error: {str(e)}", None

            except Exception as e:
                yield None, used_seed, f"Unexpected error: {str(e)}", None

        # Set up event handlers
        generate_btn.click(
            fn=generate_with_status,
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        def clear_outputs():
            # Reset image, "Seed Used", status text and info panel
            return None, -1, "Status: Ready", None

        clear_btn.click(
            fn=clear_outputs,
            inputs=None,
            outputs=[output_image, seed_used, status_msg, image_info]
        )

        # Examples
        gr.Examples(
            examples=[
                [
                    "A vibrant and dynamic graffiti mural adorns a weathered brick wall in a bustling urban alleyway, a burst of color and energy amidst the city's grit. Boldly spray-painted letters declare \"HiDream.ai\" alongside other intricate street art designs, a testament to creative expression in the urban landscape.",
                    "4:3", 82706],
                [
                    "A modern art interpretation of a traditional landscape painting, using bold colors and abstract forms to represent mountains, rivers, and mist. Incorporate calligraphic elements and a sense of dynamic energy.",
                    "1:1", 661320],
                [
                    "Intimate portrait of a young woman from a nomadic tribe in ancient China, wearing fur-trimmed clothing and intricate silver jewelry. Wind-swept hair and a resilient gaze. Background of a vast, open grassland under a dramatic sky.",
                    "1:1", 34235],
                [
                    "Time-lapse concept: A single tree shown through four seasons simultaneously, spring blossoms, summer green, autumn colors, winter snow, blended seamlessly.",
                    "1:1", 241106]
            ],
            inputs=[prompt, aspect_ratio, seed],
            outputs=[output_image, seed_used, status_msg, image_info],
            fn=generate_with_status,
            cache_examples=False
        )

    return demo
552
 
553
  # Launch app
if __name__ == "__main__":
    # Build the interface, put it behind a bounded request queue, then serve it.
    demo = create_ui()
    queued_app = demo.queue(max_size=10, default_concurrency_limit=5)
    queued_app.launch()