Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -25,7 +25,6 @@ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
|
|
25 |
import gc
|
26 |
import torch.cuda.amp as amp
|
27 |
|
28 |
-
# ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ํจ์ ๊ฐํ
|
29 |
def clear_memory():
|
30 |
"""๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ํจ์"""
|
31 |
if torch.cuda.is_available():
|
@@ -34,7 +33,11 @@ def clear_memory():
|
|
34 |
gc.collect()
|
35 |
|
36 |
# ์๋ ํผํฉ ์ ๋ฐ๋(Automatic Mixed Precision) ์ค์
|
37 |
-
|
|
|
|
|
|
|
|
|
38 |
|
39 |
|
40 |
model_name = "Helsinki-NLP/opus-mt-ko-en"
|
@@ -85,17 +88,16 @@ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_
|
|
85 |
gd_model = gd_model.to(device=device)
|
86 |
assert isinstance(gd_model, GroundingDinoForObjectDetection)
|
87 |
|
|
|
88 |
# FLUX ํ์ดํ๋ผ์ธ ์ด๊ธฐํ
|
89 |
pipe = FluxPipeline.from_pretrained(
|
90 |
"black-forest-labs/FLUX.1-dev",
|
91 |
torch_dtype=torch.float16, # A100์ ์ต์ ํ๋ float16 ์ฌ์ฉ
|
92 |
-
use_auth_token=HF_TOKEN
|
93 |
-
device_map="balanced"
|
94 |
)
|
95 |
pipe.enable_attention_slicing(slice_size="auto") # ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ต์ ํ
|
96 |
-
pipe.enable_sequential_cpu_offload() # CPU ์คํ๋ก๋ฉ ํ์ฑํ
|
97 |
-
|
98 |
|
|
|
99 |
pipe.load_lora_weights(
|
100 |
hf_hub_download(
|
101 |
"ByteDance/Hyper-SD",
|
@@ -105,6 +107,13 @@ pipe.load_lora_weights(
|
|
105 |
)
|
106 |
pipe.fuse_lora(lora_scale=0.125)
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # ๋จ์ผ GPU ์ฌ์ฉ
|
109 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" # CUDA ๋ฉ๋ชจ๋ฆฌ ํ ๋น ์ค์
|
110 |
|
@@ -195,17 +204,19 @@ def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
|
|
195 |
width, height = adjust_size_to_multiple_of_8(width, height)
|
196 |
|
197 |
with timer("Background generation"):
|
198 |
-
with torch.
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
209 |
except Exception as e:
|
210 |
print(f"Background generation error: {str(e)}")
|
211 |
return Image.new('RGB', (512, 512), 'white')
|
|
|
25 |
import gc
|
26 |
import torch.cuda.amp as amp
|
27 |
|
|
|
28 |
def clear_memory():
|
29 |
"""๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ํจ์"""
|
30 |
if torch.cuda.is_available():
|
|
|
33 |
gc.collect()
|
34 |
|
35 |
# ์๋ ํผํฉ ์ ๋ฐ๋(Automatic Mixed Precision) ์ค์
|
36 |
+
if torch.cuda.is_available():
|
37 |
+
scaler = torch.amp.GradScaler('cuda')
|
38 |
+
else:
|
39 |
+
scaler = None
|
40 |
+
|
41 |
|
42 |
|
43 |
model_name = "Helsinki-NLP/opus-mt-ko-en"
|
|
|
88 |
gd_model = gd_model.to(device=device)
|
89 |
assert isinstance(gd_model, GroundingDinoForObjectDetection)
|
90 |
|
91 |
+
|
92 |
# FLUX ํ์ดํ๋ผ์ธ ์ด๊ธฐํ
|
93 |
pipe = FluxPipeline.from_pretrained(
|
94 |
"black-forest-labs/FLUX.1-dev",
|
95 |
torch_dtype=torch.float16, # A100์ ์ต์ ํ๋ float16 ์ฌ์ฉ
|
96 |
+
use_auth_token=HF_TOKEN
|
|
|
97 |
)
|
98 |
pipe.enable_attention_slicing(slice_size="auto") # ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ต์ ํ
|
|
|
|
|
99 |
|
100 |
+
# LoRA ๊ฐ์ค์น ๋ก๋
|
101 |
pipe.load_lora_weights(
|
102 |
hf_hub_download(
|
103 |
"ByteDance/Hyper-SD",
|
|
|
107 |
)
|
108 |
pipe.fuse_lora(lora_scale=0.125)
|
109 |
|
110 |
+
# GPU ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
pipe.to("cuda")
|
113 |
+
pipe.enable_vae_slicing() # VAE ์ฌ๋ผ์ด์ฑ ํ์ฑํ
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # ๋จ์ผ GPU ์ฌ์ฉ
|
118 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" # CUDA ๋ฉ๋ชจ๋ฆฌ ํ ๋น ์ค์
|
119 |
|
|
|
204 |
width, height = adjust_size_to_multiple_of_8(width, height)
|
205 |
|
206 |
with timer("Background generation"):
|
207 |
+
with torch.inference_mode(): # inference_mode ์ฌ์ฉ
|
208 |
+
with torch.cuda.amp.autocast():
|
209 |
+
image = pipe(
|
210 |
+
prompt=prompt,
|
211 |
+
width=width,
|
212 |
+
height=height,
|
213 |
+
num_inference_steps=8,
|
214 |
+
guidance_scale=4.0,
|
215 |
+
max_length=77,
|
216 |
+
).images[0]
|
217 |
+
|
218 |
+
clear_memory() # ์ฆ์ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ
|
219 |
+
return image
|
220 |
except Exception as e:
|
221 |
print(f"Background generation error: {str(e)}")
|
222 |
return Image.new('RGB', (512, 512), 'white')
|