Spaces: Running on Zero
Update app.py

app.py CHANGED
@@ -2,9 +2,11 @@ import spaces
 import argparse
 import os
 import time
+import gc
 from os import path
 import shutil
 from datetime import datetime
+import traceback
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
 import gradio as gr
@@ -20,7 +22,9 @@ os.environ["TRANSFORMERS_CACHE"] = cache_path
 os.environ["HF_HUB_CACHE"] = cache_path
 os.environ["HF_HOME"] = cache_path
 
+# Optimize GPU memory settings
 torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.benchmark = True  # improves performance for inputs of repeated sizes
 
 def filter_prompt(prompt):
     # List of inappropriate keywords
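[Note] cudnn.benchmark pays off here because the app generates images at a small set of fixed sizes, so cuDNN can cache the fastest kernel per shape. A minimal sketch to observe the TF32 trade-off on an Ampere-class GPU (illustrative only, not part of the commit):

import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    for tf32 in (False, True):
        torch.backends.cuda.matmul.allow_tf32 = tf32
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(10):
            x @ x                      # matmul path affected by the TF32 flag
        torch.cuda.synchronize()       # wait for the GPU before stopping the clock
        print(f"allow_tf32={tf32}: {time.time() - start:.3f}s")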
@@ -53,17 +57,41 @@ class timer:
         end = time.time()
         print(f"{self.method} took {str(round(end - self.start, 2))}s")
 
-#
-
-os.makedirs(cache_path, exist_ok=True)
+# Declare the pipeline as a global variable
+pipe = None
 
-
-
-pipe
-
-
-
-
+# Model initialization function (lazy loading)
+def initialize_model():
+    global pipe
+
+    # Skip reloading if the model is already loaded
+    if pipe is not None:
+        return True
+
+    try:
+        if not path.exists(cache_path):
+            os.makedirs(cache_path, exist_ok=True)
+
+        # Run garbage collection to free memory
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        with timer("model loading"):
+            pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+            lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
+            pipe.load_lora_weights(lora_path)
+            pipe.fuse_lora(lora_scale=0.125)
+            pipe.to(device="cuda", dtype=torch.bfloat16)
+
+        # Attach the safety checker
+        pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+
+        print("Model loading complete")
+        return True
+    except Exception as e:
+        print(f"Error while loading the model: {str(e)}")
+        traceback.print_exc()
+        return False
 
 css = """
 footer {display: none !important}
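[Note] The lazy-init pattern above is only safe because requests are serialized (see queue(concurrency_count=1) at the bottom of the diff); two concurrent first requests could otherwise both see pipe is None and load the model twice. A minimal sketch of a lock-guarded variant, assuming the same global pipe (hypothetical, not in the commit):

import threading

_model_lock = threading.Lock()

def initialize_model_locked():
    global pipe
    with _model_lock:                  # only one thread may enter the load path
        if pipe is not None:
            return True                # already loaded
        return initialize_model()      # delegates to the loader above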
@@ -106,6 +134,20 @@ footer {display: none !important}
     width: 100% !important;
     max-width: 100% !important;
 }
+.loading-indicator {
+    text-align: center;
+    padding: 20px;
+    font-weight: bold;
+    color: #4B79A1;
+}
+.error-message {
+    background-color: rgba(255, 0, 0, 0.1);
+    color: red;
+    padding: 10px;
+    border-radius: 8px;
+    margin-top: 10px;
+    text-align: center;
+}
 """
 
 # Create Gradio interface
@@ -119,6 +161,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     </div>
     """)
 
+    # Status display components
+    error_message = gr.HTML(visible=False, elem_classes=["error-message"])
+    loading_status = gr.HTML(visible=False, elem_classes=["loading-indicator"])
+
     with gr.Row():
         with gr.Column(scale=3):
             prompt = gr.Textbox(
@@ -161,7 +207,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
     )
 
     def get_random_seed():
-        return torch.randint(0, 1000000, (1,)).item()
+        return int(torch.randint(0, 1000000, (1,)).item())
 
     seed = gr.Number(
         label="Seed (random by default, set for reproducibility)",
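[Note] The int(...) wrapper matters because gr.Number hands the handler a float, while torch.Generator.manual_seed() requires an int. A quick self-contained check that a fixed seed reproduces the same noise (CPU generator used here for illustration; the app uses a CUDA generator):

import torch

def noise(seed: int) -> torch.Tensor:
    g = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(4, generator=g)

assert torch.equal(noise(42), noise(42))        # same seed, same tensor
assert not torch.equal(noise(42), noise(43))    # different seed, different tensor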
@@ -211,34 +257,92 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
 
     @spaces.GPU
     def process_image(height, width, steps, scales, prompt, seed):
+        # Check the model initialization status
+        if pipe is None:
+            loading_status.update("Loading the model... the first run may take a while.", visible=True)
+
+            model_loaded = initialize_model()
+            if not model_loaded:
+                error_message.update("An error occurred while loading the model. Please refresh the page and try again.", visible=True)
+                loading_status.update(visible=False)
+                return None
+
+        loading_status.update(visible=False)
+
+        # Validate the input
+        if not prompt or prompt.strip() == "":
+            error_message.update("Please enter an image description.", visible=True)
+            return None
+
         # Prompt filtering
         is_safe, filtered_prompt = filter_prompt(prompt)
         if not is_safe:
-
+            error_message.update("The prompt contains inappropriate content.", visible=True)
             return None
 
-
-
+        # Reset the error message
+        error_message.update(visible=False)
+        loading_status.update("Generating the image...", visible=True)
+
+        try:
+            # Garbage collection to free memory
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            # Check and sanitize the seed value
+            if seed is None or not isinstance(seed, (int, float)):
+                seed = get_random_seed()
+            else:
+                seed = int(seed)  # handle the type conversion safely
+
+            # Generate the image
+            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
+                generator = torch.Generator(device="cuda").manual_seed(seed)
+
+                # Round height and width down to multiples of 64 (a FLUX model requirement)
+                height = (int(height) // 64) * 64
+                width = (int(width) // 64) * 64
+
+                # Safety guard: cap the maximum values
+                steps = min(int(steps), 25)
+                scales = max(min(float(scales), 5.0), 0.0)
+
                 generated_image = pipe(
                     prompt=[filtered_prompt],
-                    generator=
-                    num_inference_steps=
-                    guidance_scale=
-                    height=
-                    width=
+                    generator=generator,
+                    num_inference_steps=steps,
+                    guidance_scale=scales,
+                    height=height,
+                    width=width,
                     max_sequence_length=256
                 ).images[0]
 
+            loading_status.update(visible=False)
             return generated_image
+
+        except Exception as e:
+            error_msg = f"An error occurred during image generation: {str(e)}"
+            print(error_msg)
+            traceback.print_exc()
+            error_message.update(error_msg, visible=True)
+            loading_status.update(visible=False)
+
+            # Clean up memory after the error
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            return None
 
     def update_seed():
         return get_random_seed()
+
+    # Button click handler: also resets the UI state
+    def on_generate_click(height, width, steps, scales, prompt, seed):
+        error_message.update(visible=False)
+        return process_image(height, width, steps, scales, prompt, seed)
 
     generate_btn.click(
-        process_image,
+        on_generate_click,
         inputs=[height, width, steps, scales, prompt, seed],
         outputs=[output]
     )
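[Note] Calling error_message.update(...) and loading_status.update(...) imperatively inside the handler does not change the UI in Gradio; update payloads take effect only when they are returned from the handler and mapped through outputs. A minimal sketch of the conventional wiring, assuming the same component names (a suggestion, not what this commit does):

import gradio as gr

def on_generate_click(height, width, steps, scales, prompt, seed):
    if not prompt or not prompt.strip():
        # reveal the error banner and leave the image slot empty
        return None, gr.update(value="Please enter an image description.", visible=True)
    image = process_image(height, width, steps, scales, prompt, seed)
    return image, gr.update(visible=False)   # hide the banner on success

generate_btn.click(
    on_generate_click,
    inputs=[height, width, steps, scales, prompt, seed],
    outputs=[output, error_message],         # every component updated by the return value
)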
@@ -247,11 +351,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         update_seed,
         outputs=[seed]
     )
-
-    generate_btn.click(
-        update_seed,
-        outputs=[seed]
-    )
 
 if __name__ == "__main__":
-    demo.launch()
+    # Do not preload the model at app startup (lazy-load on the first request)
+    demo.queue(concurrency_count=1, max_size=10).launch()
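[Note] queue(concurrency_count=1, max_size=10) is Gradio 3.x syntax; Gradio 4.x removed the concurrency_count argument, so this line effectively pins the Space to a 3.x SDK. A hedged sketch of version-tolerant launch code (assuming the renamed 4.x parameter; not part of the commit):

import gradio as gr

if gr.__version__.startswith("3."):
    demo.queue(concurrency_count=1, max_size=10).launch()
else:
    # Gradio 4.x renamed concurrency_count to default_concurrency_limit
    demo.queue(default_concurrency_limit=1, max_size=10).launch()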