fantos committed · verified
Commit d5f9b62 · 1 Parent(s): ba3c0ae

Update app.py

Files changed (1): app.py (+129 -29)
app.py CHANGED
@@ -2,9 +2,11 @@ import spaces
 import argparse
 import os
 import time
+ import gc
 from os import path
 import shutil
 from datetime import datetime
+ import traceback
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 import gradio as gr
@@ -20,7 +22,9 @@ os.environ["TRANSFORMERS_CACHE"] = cache_path
 os.environ["HF_HUB_CACHE"] = cache_path
 os.environ["HF_HOME"] = cache_path

+ # Optimize GPU memory settings
 torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True  # improves performance when input sizes recur

 def filter_prompt(prompt):
     # List of inappropriate keywords
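Two backend toggles arrive in this hunk: TF32 trades a little float32 matmul precision for tensor-core speed on Ampere-class GPUs, and cudnn.benchmark lets cuDNN auto-tune convolution algorithms when input shapes repeat. A standalone sketch of the same flags, assuming a CUDA build of PyTorch:

import torch

# TF32 runs float32 matmuls on tensor cores with a shorter mantissa:
# faster, slightly less precise. Off by default since PyTorch 1.12.
torch.backends.cuda.matmul.allow_tf32 = True

# cudnn.benchmark times candidate convolution algorithms on the first
# call for each new input shape and caches the winner, so it only pays
# off when the same shapes recur, as with a few fixed image sizes here.
torch.backends.cudnn.benchmark = True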
@@ -53,17 +57,41 @@ class timer:
         end = time.time()
         print(f"{self.method} took {str(round(end - self.start, 2))}s")

- # Model initialization
- if not path.exists(cache_path):
-     os.makedirs(cache_path, exist_ok=True)
-
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
- pipe.fuse_lora(lora_scale=0.125)
- pipe.to(device="cuda", dtype=torch.bfloat16)
-
- # Add safety checker
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+ # Declare the pipeline as a global variable
+ pipe = None
+
+ # Model initialization function (lazy loading)
+ def initialize_model():
+     global pipe
+
+     # Skip reloading if the model is already loaded
+     if pipe is not None:
+         return
+
+     try:
+         if not path.exists(cache_path):
+             os.makedirs(cache_path, exist_ok=True)
+
+         # Run garbage collection to free memory
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         with timer("model loading"):
+             pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+             lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
+             pipe.load_lora_weights(lora_path)
+             pipe.fuse_lora(lora_scale=0.125)
+             pipe.to(device="cuda", dtype=torch.bfloat16)
+
+         # Add safety checker
+         pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+
+         print("Model loading complete")
+         return True
+     except Exception as e:
+         print(f"Error while loading the model: {str(e)}")
+         traceback.print_exc()
+         return False

 css = """
 footer {display: none !important}
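The lazy loader above uses the module-level pipe as a load-once guard, which holds here because the queue at the bottom of the file serializes requests. As a sketch only, for the hypothetical case of several worker threads, the same pattern is usually written with double-checked locking; load_model stands in for the FluxPipeline setup above:

import threading

_model = None
_model_lock = threading.Lock()

def get_model(load_model):
    """Create the model exactly once, even under concurrent callers."""
    global _model
    if _model is None:                # fast path, no lock taken
        with _model_lock:
            if _model is None:        # re-check inside the lock
                _model = load_model()
    return _model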
@@ -106,6 +134,20 @@ footer {display: none !important}
     width: 100% !important;
     max-width: 100% !important;
 }
+ .loading-indicator {
+     text-align: center;
+     padding: 20px;
+     font-weight: bold;
+     color: #4B79A1;
+ }
+ .error-message {
+     background-color: rgba(255, 0, 0, 0.1);
+     color: red;
+     padding: 10px;
+     border-radius: 8px;
+     margin-top: 10px;
+     text-align: center;
+ }
 """

 # Create Gradio interface
@@ -119,6 +161,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     </div>
     """)

+     # Status display components
+     error_message = gr.HTML(visible=False, elem_classes=["error-message"])
+     loading_status = gr.HTML(visible=False, elem_classes=["loading-indicator"])
+
     with gr.Row():
         with gr.Column(scale=3):
             prompt = gr.Textbox(
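One caveat on these two components: in Gradio, calling component.update(...) inside an event handler only builds an update dictionary; nothing reaches the browser unless that value is returned for a component listed in the event's outputs. A hedged sketch of the conventional wiring (the handler name is illustrative):

import gradio as gr

def start_generation():
    # Returned values, not in-place .update() calls, are what Gradio
    # actually applies to the page.
    return gr.update(value="Generating image...", visible=True)

# Illustrative binding: generate_btn.click(start_generation, outputs=[loading_status])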
@@ -161,7 +207,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     )

     def get_random_seed():
-         return torch.randint(0, 1000000, (1,)).item()
+         return int(torch.randint(0, 1000000, (1,)).item())

     seed = gr.Number(
         label="Seed (random by default, set for reproducibility)",
@@ -211,34 +257,92 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:

     @spaces.GPU
     def process_image(height, width, steps, scales, prompt, seed):
+         # Check whether the model has been initialized
+         if pipe is None:
+             loading_status.update("Loading the model... the first run may take a while.", visible=True)
+
+             model_loaded = initialize_model()
+             if not model_loaded:
+                 error_message.update("An error occurred while loading the model. Please refresh the page and try again.", visible=True)
+                 loading_status.update(visible=False)
+                 return None
+
+             loading_status.update(visible=False)
+
+         # Validate the input
+         if not prompt or prompt.strip() == "":
+             error_message.update("Please enter an image description.", visible=True)
+             return None
+
         # Prompt filtering
         is_safe, filtered_prompt = filter_prompt(prompt)
         if not is_safe:
-             gr.Warning("This prompt contains inappropriate content.")
+             error_message.update("This prompt contains inappropriate content.", visible=True)
             return None

-         with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
-             try:
+         # Clear the error message
+         error_message.update(visible=False)
+         loading_status.update("Generating the image...", visible=True)
+
+         try:
+             # Garbage collection to free memory
+             gc.collect()
+             torch.cuda.empty_cache()
+
+             # Check and normalize the seed value
+             if seed is None or not isinstance(seed, (int, float)):
+                 seed = get_random_seed()
+             else:
+                 seed = int(seed)  # handle the type conversion safely
+
+             # Generate the image
+             with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
+                 generator = torch.Generator(device="cuda").manual_seed(seed)
+
+                 # Round height and width to multiples of 64 (FLUX model requirement)
+                 height = (int(height) // 64) * 64
+                 width = (int(width) // 64) * 64
+
+                 # Safeguard: cap the maximum values
+                 steps = min(int(steps), 25)
+                 scales = max(min(float(scales), 5.0), 0.0)
+
                 generated_image = pipe(
                     prompt=[filtered_prompt],
-                     generator=torch.Generator().manual_seed(int(seed)),
-                     num_inference_steps=int(steps),
-                     guidance_scale=float(scales),
-                     height=int(height),
-                     width=int(width),
+                     generator=generator,
+                     num_inference_steps=steps,
+                     guidance_scale=scales,
+                     height=height,
+                     width=width,
                     max_sequence_length=256
                 ).images[0]

+             loading_status.update(visible=False)
             return generated_image
-         except Exception as e:
-             print(f"Error in image generation: {str(e)}")
-             return None
+
+         except Exception as e:
+             error_msg = f"An error occurred during image generation: {str(e)}"
+             print(error_msg)
+             traceback.print_exc()
+             error_message.update(error_msg, visible=True)
+             loading_status.update(visible=False)
+
+             # Clean up memory after the error
+             gc.collect()
+             torch.cuda.empty_cache()
+
+             return None

     def update_seed():
         return get_random_seed()
+
+     # Button click event: also reset the status UI elements
+     def on_generate_click(height, width, steps, scales, prompt, seed):
+         error_message.update(visible=False)
+         return process_image(height, width, steps, scales, prompt, seed)

     generate_btn.click(
-         process_image,
+         on_generate_click,
         inputs=[height, width, steps, scales, prompt, seed],
         outputs=[output]
     )
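The new dimension snapping floors height and width to multiples of 64, which the commit states as a FLUX model requirement; note that plain flooring can reach 0 for inputs under 64. A standalone sketch with an extra lower bound that the commit itself does not apply:

def snap_to_multiple(value, base=64):
    """Floor value to a multiple of base, but never below base."""
    return max((int(value) // base) * base, base)

assert snap_to_multiple(1025) == 1024  # 1025 -> 1024
assert snap_to_multiple(63) == 64      # guards against a zero dimension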
@@ -247,11 +351,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         update_seed,
         outputs=[seed]
     )
-
-     generate_btn.click(
-         update_seed,
-         outputs=[seed]
-     )

 if __name__ == "__main__":
-     demo.launch()
+     # Do not preload the model at app start (lazy-load on the first request)
+     demo.queue(concurrency_count=1, max_size=10).launch()
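A note on the new launch line: concurrency_count=1 serializes GPU jobs and max_size=10 bounds the waiting queue, which is also what keeps the unguarded lazy loader safe. That keyword is the Gradio 3.x spelling; a rough sketch of the Gradio 4+ equivalent, should the Space be upgraded:

# Gradio 4 dropped queue(concurrency_count=...); the closest equivalent:
demo.queue(default_concurrency_limit=1, max_size=10).launch()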
 
 