Update app.py
app.py
CHANGED
@@ -23,63 +23,55 @@ import gc
 from PIL import Image, ImageDraw, ImageFont
 
 def clear_memory():
     """Memory cleanup function"""
     gc.collect()
+    torch.cuda.empty_cache()
+
+    # Clear unused caches
+    if hasattr(torch.cuda, 'empty_cache'):
+        torch.cuda.empty_cache()
+    if hasattr(torch.cuda, 'ipc_collect'):
+        torch.cuda.ipc_collect()
 
 def initialize_models():
-    """Model initialization function - modified to fit the Spaces GPU environment"""
     global segmenter, gd_model, gd_processor, pipe, translator
 
     try:
-        # Clear GPU memory
         clear_memory()
 
-        # Translation
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_name
+        # Translation model, run on CPU only
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name,
+                                                      low_cpu_mem_usage=True).to('cpu')
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        translator = pipeline("translation", model=model, tokenizer=tokenizer,
+        translator = pipeline("translation", model=model, tokenizer=tokenizer,
+                              device=-1)
 
         # GroundingDINO model
         gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
         gd_model = GroundingDinoForObjectDetection.from_pretrained(
             gd_model_path,
             torch_dtype=torch.float16,
-            device_map='
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
 
         # Segmenter
-        segmenter = BoxSegmenter(device='
+        segmenter = BoxSegmenter(device='cpu')
 
         # FLUX pipeline
         pipe = FluxPipeline.from_pretrained(
             "black-forest-labs/FLUX.1-dev",
             torch_dtype=torch.float16,
+            device_map='auto',
+            low_cpu_mem_usage=True
         )
-        pipe.enable_attention_slicing(
-
-        # Load LoRA weights
-        pipe.load_lora_weights(
-            hf_hub_download(
-                "ByteDance/Hyper-SD",
-                "Hyper-FLUX.1-dev-8steps-lora.safetensors",
-                token=HF_TOKEN
-            )
-        )
-        pipe.fuse_lora(lora_scale=0.125)
-
-        if torch.cuda.is_available():
-            pipe = pipe.to('cuda:0')  # explicitly pin to cuda:0
-
+        pipe.enable_attention_slicing()
+        pipe.enable_model_cpu_offload()
 
     except Exception as e:
         print(f"Model initialization error: {str(e)}")
         raise
+
 # GPU setup
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # explicitly pin to cuda:0
 
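A note on the rewritten clear_memory(): torch.cuda.empty_cache() is called unconditionally and then once more behind a hasattr guard, and both empty_cache and ipc_collect have been part of torch.cuda for years, so neither guard ever fails. The check that actually matters is torch.cuda.is_available(), since this function also runs during CPU-only startup before any GPU is attached. A minimal equivalent sketch, assuming stock PyTorch only:

    import gc
    import torch

    def clear_memory() -> None:
        """Release Python garbage and cached CUDA allocations."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
            torch.cuda.ipc_collect()  # reclaim memory held by expired IPC handles

Separately, passing device_map='auto' to FluxPipeline.from_pretrained and then calling enable_model_cpu_offload() is likely to conflict: recent diffusers releases document only 'balanced' as a pipeline-level device map and refuse to offload a device-mapped pipeline. Loading without device_map and keeping the offload call is the safer combination on a memory-constrained Space.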
@@ -234,32 +226,29 @@ def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int,
         return base_size * 4 // 3, base_size
     return base_size, base_size
 
 @spaces.GPU(duration=20)
 def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
     try:
         width, height = calculate_dimensions(aspect_ratio)
         width, height = adjust_size_to_multiple_of_8(width, height)
 
+        # Maximum size limit
+        max_size = 512  # reduced from 768 to 512
         if width > max_size or height > max_size:
             ratio = max_size / max(width, height)
             width = int(width * ratio)
             height = int(height * ratio)
             width, height = adjust_size_to_multiple_of_8(width, height)
 
-        with
-            ).images[0]
-        except Exception as e:
-            print(f"Pipeline error: {str(e)}")
-            return Image.new('RGB', (width, height), 'white')
+        with torch.inference_mode():
+            image = pipe(
+                prompt=prompt,
+                width=width,
+                height=height,
+                num_inference_steps=4,  # reduced from 8 to 4
+                guidance_scale=4.0,
+                batch_size=1
+            ).images[0]
 
         return image
     except Exception as e:
 
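One caution on the new generation call: batch_size is not an accepted keyword of FluxPipeline.__call__ in diffusers, so as written the call most likely raises a TypeError that the enclosing try/except then swallows; the supported knob for image count is num_images_per_prompt. Note also that this commit removes the Hyper-SD 8-step LoRA while cutting num_inference_steps from 8 to 4, and the base FLUX.1-dev checkpoint is not step-distilled, so four steps without such a LoRA will usually yield underbaked images. A sketch with the unsupported keyword replaced, other values kept as in the diff:

    with torch.inference_mode():
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=4,    # only sensible with a few-step LoRA fused in
            guidance_scale=4.0,
            num_images_per_prompt=1,  # replaces the unsupported batch_size=1
        ).images[0]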
@@ -460,7 +449,7 @@ def update_box_button(img, box_input):
     except:
         return gr.update(interactive=False, variant="secondary")
 
-def process_image(img: Image.Image, max_size: int =
+def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
     """Image size optimization"""
     if img.width > max_size or img.height > max_size:
         ratio = max_size / max(img.width, img.height)
 
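For reference, a complete version of the restored helper; the lines past the truncation point and the LANCZOS resample choice are assumptions, not taken from the diff:

    from PIL import Image

    def process_image(img: Image.Image, max_size: int = 768) -> Image.Image:
        """Downscale an image so its longer side is at most max_size."""
        if img.width > max_size or img.height > max_size:
            ratio = max_size / max(img.width, img.height)
            new_size = (int(img.width * ratio), int(img.height * ratio))
            img = img.resize(new_size, Image.LANCZOS)  # assumed resample filter
        return img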
@@ -824,14 +813,15 @@ if __name__ == "__main__":
         queue=True
     )
 
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        max_threads=2,
+        enable_queue=True,
+        cache_examples=False,
+        show_error=True,
+        show_tips=False,
+        max_size=1,  # queue size limit
+        memory_limit="48Gi"  # memory limit setting
+    )
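Several of the new launch() keywords are doubtful for current Gradio: enable_queue and show_tips were removed in Gradio 4, cache_examples belongs to gr.Interface / gr.Examples, max_size is a queue() parameter, and memory_limit is not a Gradio argument at all (the "48Gi" value reads like a Kubernetes resource limit). On Gradio 4, unknown keywords make launch() raise a TypeError at startup. A minimal sketch of the same intent, assuming Gradio 4.x:

    demo.queue(max_size=1).launch(  # queue length cap moves to queue()
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        max_threads=2,
        show_error=True,
    )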