alvdansen commited on
Commit
e5e853f
Β·
verified Β·
1 Parent(s): 6def316

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -7,6 +7,7 @@ import spaces
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
  from PIL import Image
 
10
 
11
  # Load the JSON data
12
  with open("sdxl_lora.json", "r") as file:
@@ -40,21 +41,33 @@ def get_image(image_data):
40
  local_path = image_data.get('local_path')
41
  hf_url = image_data.get('hf_url')
42
  else:
43
- return None # or a default image path
44
-
45
- try:
46
- return local_path # Return the local path string
47
- except:
 
 
 
 
 
 
 
 
48
  try:
49
  response = requests.get(hf_url)
50
  if response.status_code == 200:
51
- with open(local_path, 'wb') as f:
52
- f.write(response.content)
53
- return local_path # Return the local path string
 
 
 
54
  except Exception as e:
55
- print(f"Failed to load image: {e}")
56
-
57
- return None # or a default image path
 
58
 
59
  @spaces.GPU
60
  def infer(
@@ -146,9 +159,7 @@ with gr.Blocks(css=css) as demo:
146
  with gr.Row():
147
  with gr.Column(scale=2):
148
  gallery = gr.Gallery(
149
- value=[(img, title) for img, title in
150
- ((get_image(item["image"]), item["title"]) for item in sdxl_loras_raw)
151
- if img is not None],
152
  label="SDXL LoRA Gallery",
153
  show_label=False,
154
  elem_id="gallery",
 
7
  import torch
8
  from diffusers import DiffusionPipeline, LCMScheduler
9
  from PIL import Image
10
+ import os
11
 
12
  # Load the JSON data
13
  with open("sdxl_lora.json", "r") as file:
 
41
  local_path = image_data.get('local_path')
42
  hf_url = image_data.get('hf_url')
43
  else:
44
+ print(f"Unexpected image_data format: {type(image_data)}")
45
+ return None
46
+
47
+ # Try loading from local path first
48
+ if local_path and os.path.exists(local_path):
49
+ try:
50
+ Image.open(local_path).verify() # Verify that it's a valid image
51
+ return local_path
52
+ except Exception as e:
53
+ print(f"Error loading local image {local_path}: {e}")
54
+
55
+ # If local path fails or doesn't exist, try URL
56
+ if hf_url:
57
  try:
58
  response = requests.get(hf_url)
59
  if response.status_code == 200:
60
+ img = Image.open(requests.get(hf_url, stream=True).raw)
61
+ img.verify() # Verify that it's a valid image
62
+ img.save(local_path) # Save for future use
63
+ return local_path
64
+ else:
65
+ print(f"Failed to fetch image from URL {hf_url}. Status code: {response.status_code}")
66
  except Exception as e:
67
+ print(f"Error loading image from URL {hf_url}: {e}")
68
+
69
+ print(f"Failed to load image for {image_data}")
70
+ return None
71
 
72
  @spaces.GPU
73
  def infer(
 
159
  with gr.Row():
160
  with gr.Column(scale=2):
161
  gallery = gr.Gallery(
162
+ value=[(get_image(item["image"]), item["title"]) for item in sdxl_loras_raw if get_image(item["image"]) is not None],
 
 
163
  label="SDXL LoRA Gallery",
164
  show_label=False,
165
  elem_id="gallery",