cai-qi commited on
Commit
a2e9868
·
verified ·
1 Parent(s): 0252f96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -19
app.py CHANGED
@@ -34,6 +34,10 @@ MAX_POLL_TIME = int(os.environ.get("MAX_POLL_TIME"))
34
  # Resolution options
35
  ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
36
 
 
 
 
 
37
 
38
  class APIError(Exception):
39
  """Custom exception for API-related errors"""
@@ -55,24 +59,31 @@ def create_request(prompt, aspect_ratio="1:1", seed=-1):
55
  Raises:
56
  APIError: If the API request fails
57
  """
 
 
58
  if not prompt or not prompt.strip():
 
59
  raise ValueError("Prompt cannot be empty")
60
 
61
  # Validate aspect ratio
62
  if aspect_ratio not in ASPECT_RATIO_OPTIONS:
 
63
  raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(ASPECT_RATIO_OPTIONS)}")
64
 
65
  # Generate random seed if not provided
66
  if seed == -1:
67
  seed = random.randint(1, 2147483647)
 
68
 
69
  # Validate seed
70
  try:
71
  seed = int(seed)
72
  if seed < -1 or seed > 2147483647:
73
- raise ValueError("Seed must be -1 or between 0 and 2147483647")
74
- except (TypeError, ValueError):
75
- raise ValueError("Seed must be an integer")
 
 
76
 
77
  headers = {
78
  "Authorization": f"Bearer {API_TOKEN}",
@@ -94,19 +105,25 @@ def create_request(prompt, aspect_ratio="1:1", seed=-1):
94
  retry_count = 0
95
  while retry_count < MAX_RETRY_COUNT:
96
  try:
97
- logger.info(f"Sending API request for prompt: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'")
98
  response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
 
 
 
 
99
  response.raise_for_status()
100
 
101
  result = response.json()
102
  if not result or "result" not in result:
103
- raise APIError("Invalid response format from API")
 
104
 
105
  task_id = result.get("result", {}).get("task_id")
106
  if not task_id:
107
- raise APIError("No task ID returned from API")
 
108
 
109
- logger.info(f"Successfully created task with ID: {task_id}")
110
  return task_id, seed
111
 
112
  except requests.exceptions.Timeout:
@@ -117,13 +134,21 @@ def create_request(prompt, aspect_ratio="1:1", seed=-1):
117
  except requests.exceptions.HTTPError as e:
118
  status_code = e.response.status_code
119
  error_message = f"HTTP error {status_code}"
 
 
 
 
 
 
 
120
 
121
  if status_code == 401:
 
122
  raise APIError("Authentication failed. Please check your API token.")
123
  elif status_code == 429:
124
  retry_count += 1
125
  wait_time = min(2 ** retry_count, 10) # Exponential backoff
126
- logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry...")
127
  time.sleep(wait_time)
128
  elif 400 <= status_code < 500:
129
  try:
@@ -131,6 +156,7 @@ def create_request(prompt, aspect_ratio="1:1", seed=-1):
131
  error_message += f": {error_detail.get('message', 'Client error')}"
132
  except:
133
  pass
 
134
  raise APIError(error_message)
135
  else:
136
  retry_count += 1
@@ -139,12 +165,15 @@ def create_request(prompt, aspect_ratio="1:1", seed=-1):
139
 
140
  except requests.exceptions.RequestException as e:
141
  logger.error(f"Request error: {str(e)}")
 
142
  raise APIError(f"Failed to connect to API: {str(e)}")
143
 
144
  except Exception as e:
145
- logger.error(f"Unexpected error: {str(e)}\n{traceback.format_exc()}")
 
146
  raise APIError(f"Unexpected error: {str(e)}")
147
 
 
148
  raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
149
 
150
 
@@ -161,7 +190,10 @@ def get_results(task_id):
161
  Raises:
162
  APIError: If the API request fails
163
  """
 
 
164
  if not task_id:
 
165
  raise ValueError("Task ID cannot be empty")
166
 
167
  url = f"{API_RESULT_URL}?task_id={task_id}"
@@ -172,11 +204,14 @@ def get_results(task_id):
172
 
173
  try:
174
  response = requests.get(url, headers=headers, timeout=10)
 
 
175
  response.raise_for_status()
176
  result = response.json()
177
 
178
  if not result or "result" not in result:
179
- raise APIError("Invalid response format from API")
 
180
 
181
  return result
182
 
@@ -186,8 +221,17 @@ def get_results(task_id):
186
 
187
  except requests.exceptions.HTTPError as e:
188
  status_code = e.response.status_code
 
 
 
 
 
 
 
 
189
  if status_code == 401:
190
- raise APIError("Authentication failed. Please check your API token.")
 
191
  elif 400 <= status_code < 500:
192
  try:
193
  error_detail = e.response.json()
@@ -202,10 +246,12 @@ def get_results(task_id):
202
 
203
  except requests.exceptions.RequestException as e:
204
  logger.warning(f"Network error when checking task {task_id}: {str(e)}")
 
205
  return None
206
 
207
  except Exception as e:
208
  logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
 
209
  return None
210
 
211
 
@@ -223,24 +269,33 @@ def download_image(image_url):
223
  Raises:
224
  APIError: If the download fails
225
  """
 
 
226
  if not image_url:
227
- raise ValueError("Image URL cannot be empty")
 
228
 
229
  retry_count = 0
230
  while retry_count < MAX_RETRY_COUNT:
231
  try:
232
- logger.info(f"Downloading image from {image_url}")
233
  response = requests.get(image_url, timeout=15)
 
 
 
234
  response.raise_for_status()
235
 
236
  # Open the image from response content
237
  image = Image.open(BytesIO(response.content))
 
238
 
239
  # Get original metadata before conversion
240
  original_metadata = {}
241
  for key, value in image.info.items():
242
  if isinstance(key, str) and isinstance(value, str):
243
  original_metadata[key] = value
 
 
244
 
245
  # Convert to PNG regardless of original format (WebP, JPEG, etc.)
246
  if image.format != 'PNG':
@@ -249,17 +304,21 @@ def download_image(image_url):
249
 
250
  # If the image has an alpha channel, preserve it, otherwise convert to RGB
251
  if 'A' in image.getbands():
 
252
  image_to_save = image
253
  else:
 
254
  image_to_save = image.convert('RGB')
255
 
256
  image_to_save.save(png_buffer, format='PNG')
257
  png_buffer.seek(0)
258
  image = Image.open(png_buffer)
 
259
 
260
  # Preserve original metadata
261
  for key, value in original_metadata.items():
262
  image.info[key] = value
 
263
 
264
  logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
265
  return image
@@ -271,6 +330,14 @@ def download_image(image_url):
271
 
272
  except requests.exceptions.HTTPError as e:
273
  status_code = e.response.status_code
 
 
 
 
 
 
 
 
274
  if 400 <= status_code < 500:
275
  error_message = f"HTTP error {status_code} when downloading image"
276
  logger.error(error_message)
@@ -282,13 +349,16 @@ def download_image(image_url):
282
 
283
  except requests.exceptions.RequestException as e:
284
  retry_count += 1
285
- logger.warning(f"Network error: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
 
286
  time.sleep(1)
287
 
288
  except Exception as e:
289
- logger.error(f"Error processing image: {str(e)}\n{traceback.format_exc()}")
 
290
  raise APIError(f"Failed to process image: {str(e)}")
291
 
 
292
  raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
293
 
294
 
@@ -303,7 +373,10 @@ def add_metadata_to_image(image, metadata):
303
  Returns:
304
  PIL.Image: Image with metadata
305
  """
 
 
306
  if not image:
 
307
  return None
308
 
309
  try:
@@ -312,9 +385,12 @@ def add_metadata_to_image(image, metadata):
312
  for key, value in image.info.items():
313
  if isinstance(key, str) and isinstance(value, str):
314
  existing_metadata[key] = value
 
 
315
 
316
  # Merge with new metadata (new values override existing ones)
317
  all_metadata = {**existing_metadata, **metadata}
 
318
 
319
  # Create a new metadata dictionary for PNG
320
  meta = PngImagePlugin.PngInfo()
@@ -326,18 +402,23 @@ def add_metadata_to_image(image, metadata):
326
  # Save with metadata to a buffer
327
  buffer = BytesIO()
328
  image.save(buffer, format='PNG', pnginfo=meta)
 
329
 
330
  # Reload the image from the buffer
331
  buffer.seek(0)
332
- return Image.open(buffer)
 
 
333
 
334
  except Exception as e:
335
- logger.error(f"Failed to add metadata to image: {str(e)}\n{traceback.format_exc()}")
 
336
  return image # Return original image if metadata addition fails
337
 
338
 
339
  # Create Gradio interface
340
  def create_ui():
 
341
  with gr.Blocks(title="HiDream-I1-Dev Image Generator", theme=gr.themes.Soft()) as demo:
342
  with gr.Row(equal_height=True):
343
  with gr.Column(scale=4):
@@ -395,16 +476,20 @@ def create_ui():
395
 
396
  # Generate function with status updates
397
  def generate_with_status(prompt, aspect_ratio, seed, progress=gr.Progress()):
 
 
398
  status_update = "Sending request to API..."
399
  yield None, seed, status_update, None
400
 
401
  try:
402
  if not prompt.strip():
 
403
  status_update = "Error: Prompt cannot be empty"
404
  yield None, seed, status_update, None
405
  return
406
 
407
  # Create request
 
408
  task_id, used_seed = create_request(prompt, aspect_ratio, seed)
409
  status_update = f"Request sent. Task ID: {task_id}. Waiting for results..."
410
  yield None, used_seed, status_update, None
@@ -413,19 +498,26 @@ def create_ui():
413
  start_time = time.time()
414
  last_completion_ratio = 0
415
  progress(0, desc="Initializing...")
 
416
 
417
  while time.time() - start_time < MAX_POLL_TIME:
 
 
 
418
  result = get_results(task_id)
419
  if not result:
 
420
  time.sleep(POLL_INTERVAL)
421
  continue
422
 
423
  sub_results = result.get("result", {}).get("sub_task_results", [])
424
  if not sub_results:
 
425
  time.sleep(POLL_INTERVAL)
426
  continue
427
 
428
  status = sub_results[0].get("task_status")
 
429
 
430
  # Get and display completion ratio
431
  completion_ratio = sub_results[0].get('task_completion', 0) * 100
@@ -435,13 +527,16 @@ def create_ui():
435
  progress_bar = "█" * int(completion_ratio / 10) + "░" * (10 - int(completion_ratio / 10))
436
  status_update = f"Generating image: {completion_ratio}% complete"
437
  progress(completion_ratio / 100, desc=f"Generating image")
 
438
  yield None, used_seed, status_update, None
439
 
440
  # Check task status
441
  if status == 1: # Success
 
442
  progress(1.0, desc="Generation complete")
443
  image_name = sub_results[0].get("image")
444
  if not image_name:
 
445
  status_update = "Error: No image name in successful response"
446
  yield None, used_seed, status_update, None
447
  return
@@ -450,10 +545,12 @@ def create_ui():
450
  yield None, used_seed, status_update, None
451
 
452
  image_url = f"{API_IMAGE_URL}{image_name}.png"
 
453
  image = download_image(image_url)
454
 
455
  if image:
456
  # Add metadata to the image
 
457
  metadata = {
458
  "prompt": prompt,
459
  "seed": str(used_seed),
@@ -475,15 +572,18 @@ def create_ui():
475
  }
476
 
477
  status_update = "Image generated successfully!"
 
478
  yield image_with_metadata, used_seed, status_update, info
479
  return
480
  else:
 
481
  status_update = "Error: Failed to download the generated image"
482
  yield None, used_seed, status_update, None
483
  return
484
 
485
  elif status in {3, 4}: # Failed or Canceled
486
  error_msg = sub_results[0].get("task_error", "Unknown error")
 
487
  status_update = f"Error: Task failed with status {status}: {error_msg}"
488
  yield None, used_seed, status_update, None
489
  return
@@ -495,18 +595,23 @@ def create_ui():
495
 
496
  time.sleep(POLL_INTERVAL)
497
 
 
498
  status_update = f"Error: Timeout waiting for image generation after {MAX_POLL_TIME} seconds"
499
  yield None, used_seed, status_update, None
500
 
501
  except APIError as e:
 
502
  status_update = f"API Error: {str(e)}"
503
  yield None, seed, status_update, None
504
 
505
  except ValueError as e:
 
506
  status_update = f"Value Error: {str(e)}"
507
  yield None, seed, status_update, None
508
 
509
  except Exception as e:
 
 
510
  status_update = f"Unexpected error: {str(e)}"
511
  yield None, seed, status_update, None
512
 
@@ -518,6 +623,7 @@ def create_ui():
518
  )
519
 
520
  def clear_outputs():
 
521
  return None, -1, "Status: Ready", None
522
 
523
  clear_btn.click(
@@ -545,12 +651,17 @@ def create_ui():
545
  inputs=[prompt, aspect_ratio, seed],
546
  outputs=[output_image, seed_used, status_msg, image_info],
547
  fn=generate_with_status,
548
- cache_examples=True
 
549
  )
550
-
 
551
  return demo
552
 
553
  # Launch app
554
  if __name__ == "__main__":
 
555
  demo = create_ui()
 
556
  demo.queue(max_size=10, default_concurrency_limit=5).launch()
 
 
34
  # Resolution options
35
  ASPECT_RATIO_OPTIONS = ["1:1", "3:4", "4:3", "9:16", "16:9"]
36
 
37
+ # Log configuration details
38
+ logger.info(f"API configuration loaded: REQUEST_URL={API_REQUEST_URL}, RESULT_URL={API_RESULT_URL}, VERSION={API_VERSION}, MODEL={API_MODEL_NAME}")
39
+ logger.info(f"Retry configuration: MAX_RETRY_COUNT={MAX_RETRY_COUNT}, POLL_INTERVAL={POLL_INTERVAL}s, MAX_POLL_TIME={MAX_POLL_TIME}s")
40
+
41
 
42
  class APIError(Exception):
43
  """Custom exception for API-related errors"""
 
59
  Raises:
60
  APIError: If the API request fails
61
  """
62
+ logger.info(f"Starting create_request with prompt='{prompt[:50]}...', aspect_ratio={aspect_ratio}, seed={seed}")
63
+
64
  if not prompt or not prompt.strip():
65
+ logger.error("Empty prompt provided to create_request")
66
  raise ValueError("Prompt cannot be empty")
67
 
68
  # Validate aspect ratio
69
  if aspect_ratio not in ASPECT_RATIO_OPTIONS:
70
+ logger.error(f"Invalid aspect ratio: {aspect_ratio}. Valid options: {', '.join(ASPECT_RATIO_OPTIONS)}")
71
  raise ValueError(f"Invalid aspect ratio. Must be one of: {', '.join(ASPECT_RATIO_OPTIONS)}")
72
 
73
  # Generate random seed if not provided
74
  if seed == -1:
75
  seed = random.randint(1, 2147483647)
76
+ logger.info(f"Generated random seed: {seed}")
77
 
78
  # Validate seed
79
  try:
80
  seed = int(seed)
81
  if seed < -1 or seed > 2147483647:
82
+ logger.error(f"Invalid seed value: {seed}")
83
+ raise ValueError(f"Seed must be -1 or between 0 and 2147483647 but got {seed}")
84
+ except (TypeError, ValueError) as e:
85
+ logger.error(f"Seed validation failed: {str(e)}")
86
+ raise ValueError(f"Seed must be an integer but got {seed}")
87
 
88
  headers = {
89
  "Authorization": f"Bearer {API_TOKEN}",
 
105
  retry_count = 0
106
  while retry_count < MAX_RETRY_COUNT:
107
  try:
108
+ logger.info(f"Sending API request [attempt {retry_count+1}/{MAX_RETRY_COUNT}] for prompt: '{prompt[:50]}...'")
109
  response = requests.post(API_REQUEST_URL, json=generate_data, headers=headers, timeout=10)
110
+
111
+ # Log response status code
112
+ logger.info(f"API request response status: {response.status_code}")
113
+
114
  response.raise_for_status()
115
 
116
  result = response.json()
117
  if not result or "result" not in result:
118
+ logger.error(f"Invalid API response format: {str(result)}")
119
+ raise APIError(f"Invalid response format from API when sending request: {str(result)}")
120
 
121
  task_id = result.get("result", {}).get("task_id")
122
  if not task_id:
123
+ logger.error(f"No task ID in API response: {str(result)}")
124
+ raise APIError(f"No task ID returned from API: {str(result)}")
125
 
126
+ logger.info(f"Successfully created task with ID: {task_id}, seed: {seed}")
127
  return task_id, seed
128
 
129
  except requests.exceptions.Timeout:
 
134
  except requests.exceptions.HTTPError as e:
135
  status_code = e.response.status_code
136
  error_message = f"HTTP error {status_code}"
137
+
138
+ try:
139
+ error_detail = e.response.json()
140
+ error_message += f": {error_detail}"
141
+ logger.error(f"API response error content: {error_detail}")
142
+ except:
143
+ logger.error(f"Could not parse API error response as JSON. Raw content: {e.response.content[:500]}")
144
 
145
  if status_code == 401:
146
+ logger.error(f"Authentication failed with API token. Status code: {status_code}")
147
  raise APIError("Authentication failed. Please check your API token.")
148
  elif status_code == 429:
149
  retry_count += 1
150
  wait_time = min(2 ** retry_count, 10) # Exponential backoff
151
+ logger.warning(f"Rate limit exceeded. Waiting {wait_time}s before retry ({retry_count}/{MAX_RETRY_COUNT})...")
152
  time.sleep(wait_time)
153
  elif 400 <= status_code < 500:
154
  try:
 
156
  error_message += f": {error_detail.get('message', 'Client error')}"
157
  except:
158
  pass
159
+ logger.error(f"Client error: {error_message}, Prompt: '{prompt[:50]}...', Status: {status_code}")
160
  raise APIError(error_message)
161
  else:
162
  retry_count += 1
 
165
 
166
  except requests.exceptions.RequestException as e:
167
  logger.error(f"Request error: {str(e)}")
168
+ logger.debug(f"Request error details: {traceback.format_exc()}")
169
  raise APIError(f"Failed to connect to API: {str(e)}")
170
 
171
  except Exception as e:
172
+ logger.error(f"Unexpected error in create_request: {str(e)}")
173
+ logger.error(f"Full traceback: {traceback.format_exc()}")
174
  raise APIError(f"Unexpected error: {str(e)}")
175
 
176
+ logger.error(f"Failed to create request after {MAX_RETRY_COUNT} retries for prompt: '{prompt[:50]}...'")
177
  raise APIError(f"Failed after {MAX_RETRY_COUNT} retries")
178
 
179
 
 
190
  Raises:
191
  APIError: If the API request fails
192
  """
193
+ logger.debug(f"Checking status for task ID: {task_id}")
194
+
195
  if not task_id:
196
+ logger.error("Empty task ID provided to get_results")
197
  raise ValueError("Task ID cannot be empty")
198
 
199
  url = f"{API_RESULT_URL}?task_id={task_id}"
 
204
 
205
  try:
206
  response = requests.get(url, headers=headers, timeout=10)
207
+ logger.debug(f"Status check response code: {response.status_code}")
208
+
209
  response.raise_for_status()
210
  result = response.json()
211
 
212
  if not result or "result" not in result:
213
+ logger.warning(f"Invalid response format from API when checking task {task_id}: {str(result)}")
214
+ raise APIError(f"Invalid response format from API when checking task {task_id}: {str(result)}")
215
 
216
  return result
217
 
 
221
 
222
  except requests.exceptions.HTTPError as e:
223
  status_code = e.response.status_code
224
+ logger.warning(f"HTTP error {status_code} when checking task {task_id}")
225
+
226
+ try:
227
+ error_content = e.response.json()
228
+ logger.error(f"Error response content: {error_content}")
229
+ except:
230
+ logger.error(f"Could not parse error response as JSON. Raw content: {e.response.content[:500]}")
231
+
232
  if status_code == 401:
233
+ logger.error(f"Authentication failed when checking task {task_id}")
234
+ raise APIError(f"Authentication failed. Please check your API token when checking task {task_id}")
235
  elif 400 <= status_code < 500:
236
  try:
237
  error_detail = e.response.json()
 
246
 
247
  except requests.exceptions.RequestException as e:
248
  logger.warning(f"Network error when checking task {task_id}: {str(e)}")
249
+ logger.debug(f"Network error details: {traceback.format_exc()}")
250
  return None
251
 
252
  except Exception as e:
253
  logger.error(f"Unexpected error when checking task {task_id}: {str(e)}")
254
+ logger.error(f"Full traceback: {traceback.format_exc()}")
255
  return None
256
 
257
 
 
269
  Raises:
270
  APIError: If the download fails
271
  """
272
+ logger.info(f"Starting download_image from URL: {image_url}")
273
+
274
  if not image_url:
275
+ logger.error("Empty image URL provided to download_image")
276
+ raise ValueError("Image URL cannot be empty when downloading image")
277
 
278
  retry_count = 0
279
  while retry_count < MAX_RETRY_COUNT:
280
  try:
281
+ logger.info(f"Downloading image [attempt {retry_count+1}/{MAX_RETRY_COUNT}] from {image_url}")
282
  response = requests.get(image_url, timeout=15)
283
+
284
+ logger.debug(f"Image download response status: {response.status_code}, Content-Type: {response.headers.get('Content-Type')}, Content-Length: {response.headers.get('Content-Length')}")
285
+
286
  response.raise_for_status()
287
 
288
  # Open the image from response content
289
  image = Image.open(BytesIO(response.content))
290
+ logger.info(f"Image opened successfully. Format: {image.format}, Size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")
291
 
292
  # Get original metadata before conversion
293
  original_metadata = {}
294
  for key, value in image.info.items():
295
  if isinstance(key, str) and isinstance(value, str):
296
  original_metadata[key] = value
297
+
298
+ logger.debug(f"Original image metadata: {original_metadata}")
299
 
300
  # Convert to PNG regardless of original format (WebP, JPEG, etc.)
301
  if image.format != 'PNG':
 
304
 
305
  # If the image has an alpha channel, preserve it, otherwise convert to RGB
306
  if 'A' in image.getbands():
307
+ logger.debug("Preserving alpha channel in image conversion")
308
  image_to_save = image
309
  else:
310
+ logger.debug("Converting image to RGB mode")
311
  image_to_save = image.convert('RGB')
312
 
313
  image_to_save.save(png_buffer, format='PNG')
314
  png_buffer.seek(0)
315
  image = Image.open(png_buffer)
316
+ logger.debug(f"Image converted to PNG. New size: {image.size[0]}x{image.size[1]}, Mode: {image.mode}")
317
 
318
  # Preserve original metadata
319
  for key, value in original_metadata.items():
320
  image.info[key] = value
321
+ logger.debug("Original metadata preserved in converted image")
322
 
323
  logger.info(f"Successfully downloaded and processed image: {image.size[0]}x{image.size[1]}")
324
  return image
 
330
 
331
  except requests.exceptions.HTTPError as e:
332
  status_code = e.response.status_code
333
+ logger.error(f"HTTP error {status_code} when downloading image from {image_url}")
334
+
335
+ try:
336
+ error_content = e.response.text[:500]
337
+ logger.error(f"Error response content: {error_content}")
338
+ except:
339
+ logger.error("Could not read error response content")
340
+
341
  if 400 <= status_code < 500:
342
  error_message = f"HTTP error {status_code} when downloading image"
343
  logger.error(error_message)
 
349
 
350
  except requests.exceptions.RequestException as e:
351
  retry_count += 1
352
+ logger.warning(f"Network error during image download: {str(e)}. Retrying ({retry_count}/{MAX_RETRY_COUNT})...")
353
+ logger.debug(f"Network error details: {traceback.format_exc()}")
354
  time.sleep(1)
355
 
356
  except Exception as e:
357
+ logger.error(f"Error processing image from {image_url}: {str(e)}")
358
+ logger.error(f"Full traceback: {traceback.format_exc()}")
359
  raise APIError(f"Failed to process image: {str(e)}")
360
 
361
+ logger.error(f"Failed to download image from {image_url} after {MAX_RETRY_COUNT} retries")
362
  raise APIError(f"Failed to download image after {MAX_RETRY_COUNT} retries")
363
 
364
 
 
373
  Returns:
374
  PIL.Image: Image with metadata
375
  """
376
+ logger.debug(f"Adding metadata to image: {metadata}")
377
+
378
  if not image:
379
+ logger.error("Null image provided to add_metadata_to_image")
380
  return None
381
 
382
  try:
 
385
  for key, value in image.info.items():
386
  if isinstance(key, str) and isinstance(value, str):
387
  existing_metadata[key] = value
388
+
389
+ logger.debug(f"Existing image metadata: {existing_metadata}")
390
 
391
  # Merge with new metadata (new values override existing ones)
392
  all_metadata = {**existing_metadata, **metadata}
393
+ logger.debug(f"Combined metadata: {all_metadata}")
394
 
395
  # Create a new metadata dictionary for PNG
396
  meta = PngImagePlugin.PngInfo()
 
402
  # Save with metadata to a buffer
403
  buffer = BytesIO()
404
  image.save(buffer, format='PNG', pnginfo=meta)
405
+ logger.debug("Image saved to buffer with metadata")
406
 
407
  # Reload the image from the buffer
408
  buffer.seek(0)
409
+ result_image = Image.open(buffer)
410
+ logger.debug("Image reloaded from buffer with metadata")
411
+ return result_image
412
 
413
  except Exception as e:
414
+ logger.error(f"Failed to add metadata to image: {str(e)}")
415
+ logger.error(f"Full traceback: {traceback.format_exc()}")
416
  return image # Return original image if metadata addition fails
417
 
418
 
419
  # Create Gradio interface
420
  def create_ui():
421
+ logger.info("Creating Gradio UI")
422
  with gr.Blocks(title="HiDream-I1-Dev Image Generator", theme=gr.themes.Soft()) as demo:
423
  with gr.Row(equal_height=True):
424
  with gr.Column(scale=4):
 
476
 
477
  # Generate function with status updates
478
  def generate_with_status(prompt, aspect_ratio, seed, progress=gr.Progress()):
479
+ logger.info(f"Starting image generation with prompt='{prompt[:50]}...', aspect_ratio={aspect_ratio}, seed={seed}")
480
+
481
  status_update = "Sending request to API..."
482
  yield None, seed, status_update, None
483
 
484
  try:
485
  if not prompt.strip():
486
+ logger.error("Empty prompt provided in UI")
487
  status_update = "Error: Prompt cannot be empty"
488
  yield None, seed, status_update, None
489
  return
490
 
491
  # Create request
492
+ logger.info("Creating API request")
493
  task_id, used_seed = create_request(prompt, aspect_ratio, seed)
494
  status_update = f"Request sent. Task ID: {task_id}. Waiting for results..."
495
  yield None, used_seed, status_update, None
 
498
  start_time = time.time()
499
  last_completion_ratio = 0
500
  progress(0, desc="Initializing...")
501
+ logger.info(f"Starting to poll for results for task ID: {task_id}")
502
 
503
  while time.time() - start_time < MAX_POLL_TIME:
504
+ elapsed_time = time.time() - start_time
505
+ logger.debug(f"Polling for results - Task ID: {task_id}, Elapsed time: {elapsed_time:.2f}s")
506
+
507
  result = get_results(task_id)
508
  if not result:
509
+ logger.debug(f"No result yet for task ID: {task_id}, waiting {POLL_INTERVAL}s...")
510
  time.sleep(POLL_INTERVAL)
511
  continue
512
 
513
  sub_results = result.get("result", {}).get("sub_task_results", [])
514
  if not sub_results:
515
+ logger.debug(f"No sub-task results yet for task ID: {task_id}, waiting {POLL_INTERVAL}s...")
516
  time.sleep(POLL_INTERVAL)
517
  continue
518
 
519
  status = sub_results[0].get("task_status")
520
+ logger.debug(f"Task status for ID {task_id}: {status}")
521
 
522
  # Get and display completion ratio
523
  completion_ratio = sub_results[0].get('task_completion', 0) * 100
 
527
  progress_bar = "█" * int(completion_ratio / 10) + "░" * (10 - int(completion_ratio / 10))
528
  status_update = f"Generating image: {completion_ratio}% complete"
529
  progress(completion_ratio / 100, desc=f"Generating image")
530
+ logger.info(f"Generation progress - Task ID: {task_id}, Completion: {completion_ratio:.1f}%")
531
  yield None, used_seed, status_update, None
532
 
533
  # Check task status
534
  if status == 1: # Success
535
+ logger.info(f"Task completed successfully - Task ID: {task_id}")
536
  progress(1.0, desc="Generation complete")
537
  image_name = sub_results[0].get("image")
538
  if not image_name:
539
+ logger.error(f"No image name in successful response. Response: {sub_results[0]}")
540
  status_update = "Error: No image name in successful response"
541
  yield None, used_seed, status_update, None
542
  return
 
545
  yield None, used_seed, status_update, None
546
 
547
  image_url = f"{API_IMAGE_URL}{image_name}.png"
548
+ logger.info(f"Downloading image - Task ID: {task_id}, URL: {image_url}")
549
  image = download_image(image_url)
550
 
551
  if image:
552
  # Add metadata to the image
553
+ logger.info(f"Adding metadata to image - Task ID: {task_id}")
554
  metadata = {
555
  "prompt": prompt,
556
  "seed": str(used_seed),
 
572
  }
573
 
574
  status_update = "Image generated successfully!"
575
+ logger.info(f"Image generation complete - Task ID: {task_id}")
576
  yield image_with_metadata, used_seed, status_update, info
577
  return
578
  else:
579
+ logger.error(f"Failed to download image - Task ID: {task_id}, URL: {image_url}")
580
  status_update = "Error: Failed to download the generated image"
581
  yield None, used_seed, status_update, None
582
  return
583
 
584
  elif status in {3, 4}: # Failed or Canceled
585
  error_msg = sub_results[0].get("task_error", "Unknown error")
586
+ logger.error(f"Task failed - Task ID: {task_id}, Status: {status}, Error: {error_msg}")
587
  status_update = f"Error: Task failed with status {status}: {error_msg}"
588
  yield None, used_seed, status_update, None
589
  return
 
595
 
596
  time.sleep(POLL_INTERVAL)
597
 
598
+ logger.error(f"Timeout waiting for task completion - Task ID: {task_id}, Max time: {MAX_POLL_TIME}s")
599
  status_update = f"Error: Timeout waiting for image generation after {MAX_POLL_TIME} seconds"
600
  yield None, used_seed, status_update, None
601
 
602
  except APIError as e:
603
+ logger.error(f"API Error during generation: {str(e)}")
604
  status_update = f"API Error: {str(e)}"
605
  yield None, seed, status_update, None
606
 
607
  except ValueError as e:
608
+ logger.error(f"Value Error during generation: {str(e)}")
609
  status_update = f"Value Error: {str(e)}"
610
  yield None, seed, status_update, None
611
 
612
  except Exception as e:
613
+ logger.error(f"Unexpected error during image generation: {str(e)}")
614
+ logger.error(f"Full traceback: {traceback.format_exc()}")
615
  status_update = f"Unexpected error: {str(e)}"
616
  yield None, seed, status_update, None
617
 
 
623
  )
624
 
625
  def clear_outputs():
626
+ logger.info("Clearing UI outputs")
627
  return None, -1, "Status: Ready", None
628
 
629
  clear_btn.click(
 
651
  inputs=[prompt, aspect_ratio, seed],
652
  outputs=[output_image, seed_used, status_msg, image_info],
653
  fn=generate_with_status,
654
+ cache_examples=True,
655
+ cache_mode="lazy"
656
  )
657
+
658
+ logger.info("Gradio UI created successfully")
659
  return demo
660
 
661
  # Launch app
662
  if __name__ == "__main__":
663
+ logger.info("Starting HiDream-I1-Dev Image Generator application")
664
  demo = create_ui()
665
+ logger.info("Launching Gradio interface with queue")
666
  demo.queue(max_size=10, default_concurrency_limit=5).launch()
667
+ logger.info("Application shutdown")