Update app.py
app.py
CHANGED
@@ -23,63 +23,55 @@ import gc
 from PIL import Image, ImageDraw, ImageFont
 
 def clear_memory():
-    """Memory cleanup function
+    """Memory cleanup function"""
     gc.collect()
-
-
-
-
-
-
+    torch.cuda.empty_cache()
+
+    # Clear caches that are no longer in use
+    if hasattr(torch.cuda, 'empty_cache'):
+        torch.cuda.empty_cache()
+    if hasattr(torch.cuda, 'ipc_collect'):
+        torch.cuda.ipc_collect()
 
 def initialize_models():
-    """Model initialization function - adjusted for the Spaces GPU environment"""
     global segmenter, gd_model, gd_processor, pipe, translator
 
     try:
-        # Clear GPU memory
         clear_memory()
 
-        # Translation
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name
+        # Translation model, pinned to the CPU
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
+                                                      low_cpu_mem_usage=True).to('cpu')
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        translator = pipeline("translation", model=model, tokenizer=tokenizer,
+        translator = pipeline("translation", model=model, tokenizer=tokenizer,
+                              device=-1)
 
-        # GroundingDINO model
+        # GroundingDINO model
         gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
         gd_model = GroundingDinoForObjectDetection.from_pretrained(
             gd_model_path,
             torch_dtype=torch.float16,
-            device_map='
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
 
-        # Segmenter
-        segmenter = BoxSegmenter(device='
+        # Segmenter
+        segmenter = BoxSegmenter(device='cpu')
 
-        # FLUX pipeline
+        # FLUX pipeline
         pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.float16,
-
+            device_map='auto',
+            low_cpu_mem_usage=True
        )
-        pipe.enable_attention_slicing(
+        pipe.enable_attention_slicing()
+        pipe.enable_model_cpu_offload()
 
-        # Load the LoRA weights
-        pipe.load_lora_weights(
-            hf_hub_download(
-                "ByteDance/Hyper-SD",
-                "Hyper-FLUX.1-dev-8steps-lora.safetensors",
-                token=HF_TOKEN
-            )
-        )
-        pipe.fuse_lora(lora_scale=0.125)
-
-        if torch.cuda.is_available():
-            pipe = pipe.to('cuda:0')  # pin explicitly to cuda:0
-
     except Exception as e:
         print(f"Model initialization error: {str(e)}")
         raise
+
 # GPU setup
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # pin explicitly to cuda:0
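Two notes on this hunk. The right side does fix the left side's unterminated docstring (its closing quotes were missing, which would have swallowed the lines after it), but the new `clear_memory` then calls `torch.cuda.empty_cache()` twice and guards two attributes that exist in every recent PyTorch build; `empty_cache()` is already a no-op when CUDA was never initialized. A minimal sketch of a tighter variant, guarding on availability instead (my simplification, not what the commit ships):

    import gc
    import torch

    def clear_memory():
        """Release Python garbage and CUDA allocator caches."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
            torch.cuda.ipc_collect()  # reclaim memory held by dead IPC handles

Separately, `FluxPipeline.from_pretrained(..., device_map='auto')` followed by `enable_model_cpu_offload()` is worth double-checking: diffusers pipelines have historically accepted only `device_map="balanced"` at the pipeline level, and offload hooks generally cannot be installed on a pipeline loaded with a device map, so one of the two mechanisms likely has to go. Verify against the installed diffusers version.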
@@ -234,32 +226,29 @@ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
         return base_size * 4 // 3, base_size
     return base_size, base_size
 
-@spaces.GPU(duration=20)
+@spaces.GPU(duration=20)
 def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
     try:
         width, height = calculate_dimensions(aspect_ratio)
         width, height = adjust_size_to_multiple_of_8(width, height)
 
-
+        # Cap the maximum size
+        max_size = 512  # reduced from 768 to 512
         if width > max_size or height > max_size:
             ratio = max_size / max(width, height)
             width = int(width * ratio)
             height = int(height * ratio)
             width, height = adjust_size_to_multiple_of_8(width, height)
 
-        with
-
-
-
-
-
-
-
-
-        ).images[0]
-    except Exception as e:
-        print(f"Pipeline error: {str(e)}")
-        return Image.new('RGB', (width, height), 'white')
+        with torch.inference_mode():
+            image = pipe(
+                prompt=prompt,
+                width=width,
+                height=height,
+                num_inference_steps=4,  # reduced from 8 to 4
+                guidance_scale=4.0,
+                batch_size=1
+            ).images[0]
 
         return image
     except Exception as e:
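`@spaces.GPU(duration=20)` asks the Spaces ZeroGPU scheduler for a GPU slot of roughly 20 seconds per call, which is why the hunk also shrinks the work: a smaller canvas (512) and fewer steps (4). Two caveats. The previous hunk dropped the Hyper-SD 8-step LoRA that justified few-step sampling, and base FLUX.1-dev is tuned for far more than 4 steps, so quality is likely to degrade noticeably. And `batch_size` is not a `FluxPipeline.__call__` parameter in diffusers (the closest knob is `num_images_per_prompt`), so the call as written would likely raise a TypeError. A hedged sketch of the generation step with that kwarg swapped out, assuming the surrounding variables from the hunk:

    with torch.inference_mode():
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,
            guidance_scale=4.0,
            num_images_per_prompt=1,  # supported replacement for batch_size
        ).images[0]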
@@ -460,7 +449,7 @@ def update_box_button(img, box_input):
     except:
         return gr.update(interactive=False, variant="secondary")
 
-def process_image(img: Image.Image, max_size: int =
+def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
     """Optimize the image size"""
     if img.width > max_size or img.height > max_size:
         ratio = max_size / max(img.width, img.height)
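Only the signature changed in this hunk. Continued from the context lines, the full helper presumably reads as below; the `resize` call and `new_size` name are my reconstruction (hypothetical), since the context stops at the ratio computation:

    from PIL import Image

    def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
        """Downscale an image so its longest side is at most max_size."""
        if img.width > max_size or img.height > max_size:
            ratio = max_size / max(img.width, img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img = img.resize(new_size, Image.LANCZOS)  # assumed resampling choice
        return img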
@@ -824,14 +813,15 @@ if __name__ == "__main__":
         queue=True
     )
 
-
-
-
-
-
-
-
-
-
-
-
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        max_threads=2,
+        enable_queue=True,
+        cache_examples=False,
+        show_error=True,
+        show_tips=False,
+        max_size=1,  # limit queue size
+        memory_limit="48Gi"  # set memory limit
+    )
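Several of these keyword arguments are not part of `Blocks.launch()` in current Gradio: `enable_queue` and `show_tips` were Gradio 3 options that are gone in Gradio 4, `cache_examples` belongs to `gr.Interface`/`gr.Examples`, `max_size` is a `queue()` parameter, and to my knowledge no Gradio release has ever had a `memory_limit` launch option (on Spaces, memory is governed by the hardware tier). With unknown kwargs, `launch()` would likely raise a TypeError and the Space would fail to start. A sketch of an equivalent launch for Gradio 4.x, keeping only the supported options:

    demo.queue(max_size=1)  # limit pending requests here rather than in launch()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=2,
        show_error=True,
    )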