fantos commited on
Commit
ba3c0ae
ยท
verified ยท
1 Parent(s): f747c76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -138
app.py CHANGED
@@ -15,9 +15,6 @@ from PIL import Image
15
 
16
  # Setup and initialization code
17
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
18
- # Use PERSISTENT_DIR environment variable for Spaces
19
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
20
- gallery_path = path.join(PERSISTENT_DIR, "gallery")
21
 
22
  os.environ["TRANSFORMERS_CACHE"] = cache_path
23
  os.environ["HF_HUB_CACHE"] = cache_path
@@ -25,10 +22,6 @@ os.environ["HF_HOME"] = cache_path
25
 
26
  torch.backends.cuda.matmul.allow_tf32 = True
27
 
28
- # Create gallery directory if it doesn't exist
29
- if not path.exists(gallery_path):
30
- os.makedirs(gallery_path, exist_ok=True)
31
-
32
  def filter_prompt(prompt):
33
  # ๋ถ€์ ์ ˆํ•œ ํ‚ค์›Œ๋“œ ๋ชฉ๋ก
34
  inappropriate_keywords = [
@@ -101,45 +94,6 @@ footer {display: none !important}
101
  -webkit-background-clip: text;
102
  -webkit-text-fill-color: transparent;
103
  }
104
- #gallery {
105
- width: 100% !important;
106
- max-width: 100% !important;
107
- overflow: visible !important;
108
- }
109
- #gallery > div {
110
- width: 100% !important;
111
- max-width: none !important;
112
- }
113
- #gallery > div > div {
114
- width: 100% !important;
115
- display: grid !important;
116
- grid-template-columns: repeat(5, 1fr) !important;
117
- gap: 16px !important;
118
- padding: 16px !important;
119
- }
120
- .gallery-container {
121
- background: rgba(255, 255, 255, 0.05);
122
- border-radius: 8px;
123
- margin-top: 10px;
124
- width: 100% !important;
125
- box-sizing: border-box !important;
126
- }
127
- .gallery-item {
128
- width: 100% !important;
129
- aspect-ratio: 1 !important;
130
- overflow: hidden !important;
131
- border-radius: 4px !important;
132
- }
133
- .gallery-item img {
134
- width: 100% !important;
135
- height: 100% !important;
136
- object-fit: cover !important;
137
- border-radius: 4px !important;
138
- transition: transform 0.2s;
139
- }
140
- .gallery-item img:hover {
141
- transform: scale(1.05);
142
- }
143
  .output-image {
144
  width: 100% !important;
145
  max-width: 100% !important;
@@ -152,76 +106,8 @@ footer {display: none !important}
152
  width: 100% !important;
153
  max-width: 100% !important;
154
  }
155
- .gallery-container::-webkit-scrollbar {
156
- display: none !important;
157
- }
158
- .gallery-container {
159
- -ms-overflow-style: none !important;
160
- scrollbar-width: none !important;
161
- }
162
- #gallery > div {
163
- width: 100% !important;
164
- max-width: 100% !important;
165
- }
166
- #gallery > div > div {
167
- width: 100% !important;
168
- max-width: 100% !important;
169
- }
170
  """
171
 
172
- def save_image(image):
173
- """Save the generated image and return the path"""
174
- try:
175
- if not os.path.exists(gallery_path):
176
- try:
177
- os.makedirs(gallery_path, exist_ok=True)
178
- except Exception as e:
179
- print(f"Failed to create gallery directory: {str(e)}")
180
- return None
181
-
182
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
183
- random_suffix = os.urandom(4).hex()
184
- filename = f"generated_{timestamp}_{random_suffix}.png"
185
- filepath = os.path.join(gallery_path, filename)
186
-
187
- try:
188
- if isinstance(image, Image.Image):
189
- image.save(filepath, "PNG", quality=100)
190
- else:
191
- image = Image.fromarray(image)
192
- image.save(filepath, "PNG", quality=100)
193
-
194
- if not os.path.exists(filepath):
195
- print(f"Warning: Failed to verify saved image at {filepath}")
196
- return None
197
-
198
- return filepath
199
- except Exception as e:
200
- print(f"Failed to save image: {str(e)}")
201
- return None
202
-
203
- except Exception as e:
204
- print(f"Error in save_image: {str(e)}")
205
- return None
206
-
207
- def load_gallery():
208
- """Load all images from the gallery directory"""
209
- try:
210
- os.makedirs(gallery_path, exist_ok=True)
211
-
212
- image_files = []
213
- for f in os.listdir(gallery_path):
214
- if f.lower().endswith(('.png', '.jpg', '.jpeg')):
215
- full_path = os.path.join(gallery_path, f)
216
- image_files.append((full_path, os.path.getmtime(full_path)))
217
-
218
- image_files.sort(key=lambda x: x[1], reverse=True)
219
-
220
- return [f[0] for f in image_files]
221
- except Exception as e:
222
- print(f"Error loading gallery: {str(e)}")
223
- return []
224
-
225
  # Create Gradio interface
226
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
227
  gr.HTML('<div class="title">AI Image Generator</div>')
@@ -322,27 +208,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
322
  elem_id="output-image",
323
  elem_classes=["output-image", "fixed-width"]
324
  )
325
-
326
- gallery = gr.Gallery(
327
- label="Generated Images Gallery",
328
- show_label=True,
329
- elem_id="gallery",
330
- columns=[4],
331
- rows=[2],
332
- height="auto",
333
- object_fit="cover",
334
- elem_classes=["gallery-container", "fixed-width"]
335
- )
336
-
337
- gallery.value = load_gallery()
338
 
339
  @spaces.GPU
340
- def process_and_save_image(height, width, steps, scales, prompt, seed):
341
  # ํ”„๋กฌํ”„ํŠธ ํ•„ํ„ฐ๋ง
342
  is_safe, filtered_prompt = filter_prompt(prompt)
343
  if not is_safe:
344
  gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
345
- return None, load_gallery()
346
 
347
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
348
  try:
@@ -356,22 +229,18 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
356
  max_sequence_length=256
357
  ).images[0]
358
 
359
- saved_path = save_image(generated_image)
360
- if saved_path is None:
361
- print("Warning: Failed to save generated image")
362
-
363
- return generated_image, load_gallery()
364
  except Exception as e:
365
  print(f"Error in image generation: {str(e)}")
366
- return None, load_gallery()
367
 
368
  def update_seed():
369
  return get_random_seed()
370
 
371
  generate_btn.click(
372
- process_and_save_image,
373
  inputs=[height, width, steps, scales, prompt, seed],
374
- outputs=[output, gallery]
375
  )
376
 
377
  randomize_seed.click(
@@ -385,4 +254,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
385
  )
386
 
387
  if __name__ == "__main__":
388
- demo.launch(allowed_paths=[PERSISTENT_DIR])
 
15
 
16
  # Setup and initialization code
17
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 
 
 
18
 
19
  os.environ["TRANSFORMERS_CACHE"] = cache_path
20
  os.environ["HF_HUB_CACHE"] = cache_path
 
22
 
23
  torch.backends.cuda.matmul.allow_tf32 = True
24
 
 
 
 
 
25
  def filter_prompt(prompt):
26
  # ๋ถ€์ ์ ˆํ•œ ํ‚ค์›Œ๋“œ ๋ชฉ๋ก
27
  inappropriate_keywords = [
 
94
  -webkit-background-clip: text;
95
  -webkit-text-fill-color: transparent;
96
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  .output-image {
98
  width: 100% !important;
99
  max-width: 100% !important;
 
106
  width: 100% !important;
107
  max-width: 100% !important;
108
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  """
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # Create Gradio interface
112
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
113
  gr.HTML('<div class="title">AI Image Generator</div>')
 
208
  elem_id="output-image",
209
  elem_classes=["output-image", "fixed-width"]
210
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  @spaces.GPU
213
+ def process_image(height, width, steps, scales, prompt, seed):
214
  # ํ”„๋กฌํ”„ํŠธ ํ•„ํ„ฐ๋ง
215
  is_safe, filtered_prompt = filter_prompt(prompt)
216
  if not is_safe:
217
  gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
218
+ return None
219
 
220
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
221
  try:
 
229
  max_sequence_length=256
230
  ).images[0]
231
 
232
+ return generated_image
 
 
 
 
233
  except Exception as e:
234
  print(f"Error in image generation: {str(e)}")
235
+ return None
236
 
237
  def update_seed():
238
  return get_random_seed()
239
 
240
  generate_btn.click(
241
+ process_image,
242
  inputs=[height, width, steps, scales, prompt, seed],
243
+ outputs=[output]
244
  )
245
 
246
  randomize_seed.click(
 
254
  )
255
 
256
  if __name__ == "__main__":
257
+ demo.launch()