fantaxy commited on
Commit
645ebcd
Β·
verified Β·
1 Parent(s): 49c67fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -27
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  import io
4
  import random
5
  import os
 
6
  from PIL import Image
7
  import json
8
 
@@ -248,33 +249,45 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
248
  }
249
  }
250
 
251
- # νƒ€μž„μ•„μ›ƒ 값을 늘리고 μž¬μ‹œλ„ 둜직 μΆ”κ°€
252
  max_retries = 3
253
  current_retry = 0
 
254
 
255
  while current_retry < max_retries:
256
  try:
257
- response = requests.post(API_URL, headers=headers, json=payload, timeout=180) # νƒ€μž„μ•„μ›ƒμ„ 180초둜 증가
258
  response.raise_for_status()
259
 
260
  image = Image.open(io.BytesIO(response.content))
261
  print(f'Generation {key} completed successfully')
262
  return image
263
 
264
- except requests.exceptions.Timeout:
 
265
  current_retry += 1
266
  if current_retry < max_retries:
267
- print(f"Timeout occurred. Retrying... (Attempt {current_retry + 1}/{max_retries})")
 
 
268
  continue
269
  else:
270
- raise gr.Error(f"Request timed out after {max_retries} attempts. The model might be busy, please try again later.")
 
 
 
 
 
 
 
 
 
271
 
272
- except requests.exceptions.RequestException as e:
273
- raise gr.Error(f"Request failed: {str(e)}")
274
 
275
- except requests.exceptions.RequestException as e:
276
- error_message = f"Request failed: {str(e)}"
277
- if hasattr(e, 'response') and e.response is not None:
278
  if e.response.status_code == 401:
279
  error_message = "Invalid API token. Please check your Hugging Face API token."
280
  elif e.response.status_code == 403:
@@ -282,8 +295,6 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
282
  elif e.response.status_code == 503:
283
  error_message = "Model is currently loading. Please try again in a few moments."
284
  raise gr.Error(error_message)
285
- except Exception as e:
286
- raise gr.Error(f"Unexpected error: {str(e)}")
287
 
288
 
289
  def generate_grid(prompt, selected_models, custom_lora, negative_prompt, steps, cfg_scale, seed, strength, width, height, progress=gr.Progress()):
@@ -292,37 +303,61 @@ def generate_grid(prompt, selected_models, custom_lora, negative_prompt, steps,
292
  if len(selected_models) == 0:
293
  raise gr.Error("Please select at least 1 model")
294
 
295
- # 초기 이미지 λ°°μ—΄ 생성
296
  images = [None] * 4
297
  total_models = len(selected_models[:4])
298
 
299
  def update_gallery():
300
- # None이 μ•„λ‹Œ μ΄λ―Έμ§€λ§Œ ν¬ν•¨ν•˜μ—¬ 가러리 μ—…λ°μ΄νŠΈ
301
  return [img for img in images if img is not None]
302
 
303
- # 각 λͺ¨λΈλ³„λ‘œ 이미지 생성
 
 
 
304
  for idx, model_name in enumerate(selected_models[:4]):
305
  try:
306
  progress((idx + 1) / total_models, f"Generating image for {model_name}...")
307
  img = query(prompt, model_name, custom_lora, negative_prompt, steps, cfg_scale, seed, strength, width, height)
308
  images[idx] = img
309
- # 이미지가 생성될 λ•Œλ§ˆλ‹€ 가러리 μ—…λ°μ΄νŠΈ
 
 
 
 
 
310
  yield update_gallery()
311
  except Exception as e:
312
  print(f"Error generating image for {model_name}: {str(e)}")
 
313
  continue
314
 
315
- # 남은 μŠ¬λ‘―μ„ λ§ˆμ§€λ§‰ μƒμ„±λœ μ΄λ―Έμ§€λ‘œ 채움
316
- last_valid_image = next((img for img in reversed(images) if img is not None), None)
317
- if last_valid_image:
318
  for i in range(len(images)):
319
  if images[i] is None:
320
- images[i] = last_valid_image
 
 
 
 
 
 
321
 
322
  progress(1.0, "Generation complete!")
323
  yield update_gallery()
324
 
325
 
 
 
 
 
 
 
 
 
 
 
326
 
327
  css = """
328
  footer {
@@ -374,17 +409,17 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
374
  lines=1
375
  )
376
 
377
- # μƒμœ„ 4개 λͺ¨λΈμ„ 기본으둜 μ„€μ •
378
  default_models = [
379
- "FLUX.1 [Schnell]", # λͺ¨λΈ 이름 톡일
380
  "Stable Diffusion 3.5 Large",
381
  "Stable Diffusion 3.5 Large Turbo",
382
  "Midjourney"
383
  ]
384
 
385
- # 전체 λͺ¨λΈ 리슀트
386
  models_list = [
387
- "FLUX.1 [Schnell]", # λͺ¨λΈ 이름 톡일
388
  "Stable Diffusion 3.5 Large",
389
  "Stable Diffusion 3.5 Large Turbo",
390
  "Stable Diffusion XL",
@@ -428,7 +463,15 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
428
  with gr.Row():
429
  generate_btn = gr.Button("Generate 2x2 Grid", variant="primary", size="lg")
430
 
431
-
 
 
 
 
 
 
 
 
432
 
433
  with gr.Row():
434
  gallery = gr.Gallery(
@@ -438,10 +481,10 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
438
  columns=2,
439
  rows=2,
440
  height="auto",
441
- preview=True, # μ‹€μ‹œκ°„ 프리뷰 ν™œμ„±ν™”
442
  )
443
 
444
- # 이벀트 ν•Έλ“€λŸ¬ μˆ˜μ •
445
  generate_btn.click(
446
  fn=generate_grid,
447
  inputs=[
@@ -465,6 +508,9 @@ with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as dalle:
465
  return gr.update(choices=filtered_models, value=[])
466
 
467
  model_search.change(filter_models, inputs=model_search, outputs=model)
 
 
 
468
 
469
  if __name__ == "__main__":
470
  dalle.launch(show_api=False, share=False)
 
3
  import io
4
  import random
5
  import os
6
+ import time
7
  from PIL import Image
8
  import json
9
 
 
249
  }
250
  }
251
 
252
+ # Improved retry logic with exponential backoff
253
  max_retries = 3
254
  current_retry = 0
255
+ backoff_factor = 2 # Exponential backoff
256
 
257
  while current_retry < max_retries:
258
  try:
259
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=180) # 3-minute timeout
260
  response.raise_for_status()
261
 
262
  image = Image.open(io.BytesIO(response.content))
263
  print(f'Generation {key} completed successfully')
264
  return image
265
 
266
+ except (requests.exceptions.Timeout, requests.exceptions.ConnectionError,
267
+ requests.exceptions.HTTPError, requests.exceptions.RequestException) as e:
268
  current_retry += 1
269
  if current_retry < max_retries:
270
+ wait_time = backoff_factor ** current_retry # Exponential backoff
271
+ print(f"Network error occurred: {str(e)}. Retrying in {wait_time} seconds... (Attempt {current_retry + 1}/{max_retries})")
272
+ time.sleep(wait_time) # Add delay before retry
273
  continue
274
  else:
275
+ # Detailed error message based on exception type
276
+ if isinstance(e, requests.exceptions.Timeout):
277
+ error_msg = f"Request timed out after {max_retries} attempts. The model might be busy, please try again later."
278
+ elif isinstance(e, requests.exceptions.ConnectionError):
279
+ error_msg = f"Connection error after {max_retries} attempts. Please check your network connection."
280
+ elif isinstance(e, requests.exceptions.HTTPError):
281
+ status_code = e.response.status_code if hasattr(e, 'response') and e.response is not None else "unknown"
282
+ error_msg = f"HTTP error (status code: {status_code}) after {max_retries} attempts."
283
+ else:
284
+ error_msg = f"Request failed after {max_retries} attempts: {str(e)}"
285
 
286
+ raise gr.Error(error_msg)
 
287
 
288
+ except Exception as e:
289
+ error_message = f"Unexpected error: {str(e)}"
290
+ if isinstance(e, requests.exceptions.RequestException) and hasattr(e, 'response') and e.response is not None:
291
  if e.response.status_code == 401:
292
  error_message = "Invalid API token. Please check your Hugging Face API token."
293
  elif e.response.status_code == 403:
 
295
  elif e.response.status_code == 503:
296
  error_message = "Model is currently loading. Please try again in a few moments."
297
  raise gr.Error(error_message)
 
 
298
 
299
 
300
  def generate_grid(prompt, selected_models, custom_lora, negative_prompt, steps, cfg_scale, seed, strength, width, height, progress=gr.Progress()):
 
303
  if len(selected_models) == 0:
304
  raise gr.Error("Please select at least 1 model")
305
 
306
+ # Initialize image array
307
  images = [None] * 4
308
  total_models = len(selected_models[:4])
309
 
310
  def update_gallery():
311
+ # Only include non-None images for gallery update
312
  return [img for img in images if img is not None]
313
 
314
+ # Create placeholder for failed models
315
+ placeholder_image = None
316
+
317
+ # Generate image for each model
318
  for idx, model_name in enumerate(selected_models[:4]):
319
  try:
320
  progress((idx + 1) / total_models, f"Generating image for {model_name}...")
321
  img = query(prompt, model_name, custom_lora, negative_prompt, steps, cfg_scale, seed, strength, width, height)
322
  images[idx] = img
323
+
324
+ # If this is the first successful generation, save as placeholder for failed models
325
+ if placeholder_image is None:
326
+ placeholder_image = img
327
+
328
+ # Update gallery after each successful generation
329
  yield update_gallery()
330
  except Exception as e:
331
  print(f"Error generating image for {model_name}: {str(e)}")
332
+ # Keep the slot as None and continue with next model
333
  continue
334
 
335
+ # Fill empty slots with a placeholder (either the last successful image or a blank image)
336
+ if placeholder_image:
 
337
  for i in range(len(images)):
338
  if images[i] is None:
339
+ # Create a copy of placeholder to avoid reference issues
340
+ images[i] = placeholder_image.copy()
341
+ else:
342
+ # If all models failed, create a blank image with error text
343
+ for i in range(len(images)):
344
+ blank_img = Image.new('RGB', (width, height), color=(240, 240, 240))
345
+ images[i] = blank_img
346
 
347
  progress(1.0, "Generation complete!")
348
  yield update_gallery()
349
 
350
 
351
+ def check_network_connectivity():
352
+ """Utility function to check network connectivity to the Hugging Face API"""
353
+ try:
354
+ response = requests.get("https://api-inference.huggingface.co", timeout=5)
355
+ if response.status_code == 200:
356
+ return True
357
+ return False
358
+ except:
359
+ return False
360
+
361
 
362
  css = """
363
  footer {
 
409
  lines=1
410
  )
411
 
412
+ # Set top 4 models as default
413
  default_models = [
414
+ "FLUX.1 [Schnell]",
415
  "Stable Diffusion 3.5 Large",
416
  "Stable Diffusion 3.5 Large Turbo",
417
  "Midjourney"
418
  ]
419
 
420
+ # Full model list
421
  models_list = [
422
+ "FLUX.1 [Schnell]",
423
  "Stable Diffusion 3.5 Large",
424
  "Stable Diffusion 3.5 Large Turbo",
425
  "Stable Diffusion XL",
 
463
  with gr.Row():
464
  generate_btn = gr.Button("Generate 2x2 Grid", variant="primary", size="lg")
465
 
466
+ # Add network status indicator
467
+ network_status = gr.Markdown("", elem_id="network_status")
468
+
469
+ # Function to check and update network status
470
+ def update_network_status():
471
+ if check_network_connectivity():
472
+ return "βœ… Connected to Hugging Face API"
473
+ else:
474
+ return "❌ No connection to Hugging Face API. Please check your network."
475
 
476
  with gr.Row():
477
  gallery = gr.Gallery(
 
481
  columns=2,
482
  rows=2,
483
  height="auto",
484
+ preview=True,
485
  )
486
 
487
+ # Event handlers
488
  generate_btn.click(
489
  fn=generate_grid,
490
  inputs=[
 
508
  return gr.update(choices=filtered_models, value=[])
509
 
510
  model_search.change(filter_models, inputs=model_search, outputs=model)
511
+
512
+ # Update network status when the app loads
513
+ dalle.load(fn=update_network_status, outputs=network_status)
514
 
515
  if __name__ == "__main__":
516
  dalle.launch(show_api=False, share=False)