Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -20,25 +20,19 @@ from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor 
     | 
|
| 20 | 
         
             
            from diffusers import FluxPipeline
         
     | 
| 21 | 
         
             
            from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
         
     | 
| 22 | 
         
             
            import gc
         
     | 
| 23 | 
         
            -
            import base64
         
     | 
| 24 | 
         | 
| 25 | 
         
            -
             
     | 
| 26 | 
         
            -
            # GPU ์ค์ 
         
     | 
| 27 | 
         
            -
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # ๋ช
์์ ์ผ๋ก cuda:0 ์ง์ 
         
     | 
| 28 | 
         
            -
             
     | 
| 29 | 
         
            -
            ###--------------ZERO GPU ํ์/ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ๊ณตํต --------------------###
         
     | 
| 30 | 
         
             
            def clear_memory():
         
     | 
| 31 | 
         
             
                """๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ํจ์"""
         
     | 
| 32 | 
         
             
                gc.collect()
         
     | 
| 33 | 
         
            -
                 
     | 
| 34 | 
         
            -
                     
     | 
| 35 | 
         
            -
                        with torch.cuda.device( 
     | 
| 36 | 
         
             
                            torch.cuda.empty_cache()
         
     | 
| 37 | 
         
            -
             
     | 
| 38 | 
         
            -
                     
     | 
| 39 | 
         
            -
                        print(f"Warning: Could not clear CUDA memory: {e}")
         
     | 
| 40 | 
         | 
| 41 | 
         
            -
             
     | 
| 
         | 
|
| 42 | 
         | 
| 43 | 
         
             
            # GPU ์ค์ ์ try-except๋ก ๊ฐ์ธ๊ธฐ
         
     | 
| 44 | 
         
             
            if torch.cuda.is_available():
         
     | 
| 
         @@ -94,35 +88,14 @@ gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_ 
     | 
|
| 94 | 
         
             
            gd_model = gd_model.to(device=device)
         
     | 
| 95 | 
         
             
            assert isinstance(gd_model, GroundingDinoForObjectDetection)
         
     | 
| 96 | 
         | 
| 97 | 
         
            -
            # ํ์ดํ๋ผ์ธ ์ด๊ธฐํ 
     | 
| 98 | 
         
             
            pipe = FluxPipeline.from_pretrained(
         
     | 
| 99 | 
         
             
                "black-forest-labs/FLUX.1-dev",
         
     | 
| 100 | 
         
             
                torch_dtype=torch.float16,
         
     | 
| 101 | 
         
             
                use_auth_token=HF_TOKEN
         
     | 
| 102 | 
         
             
            )
         
     | 
| 103 | 
         
            -
             
     | 
| 104 | 
         
            -
            # ๋ฉ๋ชจ๋ฆฌ ์ต์ ํ ์ค์  - FluxPipeline์์ ์ง์ํ๋ ๋ฉ์๋๋ง ์ฌ์ฉ
         
     | 
| 105 | 
         
             
            pipe.enable_attention_slicing(slice_size="auto")
         
     | 
| 106 | 
         | 
| 107 | 
         
            -
            # xformers ์ต์ ํ (์ค์น๋์ด ์๋ ๊ฒฝ์ฐ์๋ง)
         
     | 
| 108 | 
         
            -
            try:
         
     | 
| 109 | 
         
            -
                import xformers
         
     | 
| 110 | 
         
            -
                pipe.enable_xformers_memory_efficient_attention()
         
     | 
| 111 | 
         
            -
            except ImportError:
         
     | 
| 112 | 
         
            -
                print("xformers is not installed. Skipping memory efficient attention.")
         
     | 
| 113 | 
         
            -
             
     | 
| 114 | 
         
            -
            # GPU ์ค์ 
         
     | 
| 115 | 
         
            -
            if torch.cuda.is_available():
         
     | 
| 116 | 
         
            -
                try:
         
     | 
| 117 | 
         
            -
                    pipe = pipe.to("cuda:0")
         
     | 
| 118 | 
         
            -
                    # CPU ์คํ๋ก๋ฉ์ด ์ง์๋๋ ๊ฒฝ์ฐ์๋ง ํ์ฑํ
         
     | 
| 119 | 
         
            -
                    if hasattr(pipe, 'enable_model_cpu_offload'):
         
     | 
| 120 | 
         
            -
                        pipe.enable_model_cpu_offload()
         
     | 
| 121 | 
         
            -
                except Exception as e:
         
     | 
| 122 | 
         
            -
                    print(f"Warning: Could not move pipeline to CUDA: {str(e)}")
         
     | 
| 123 | 
         
            -
             
     | 
| 124 | 
         
            -
                
         
     | 
| 125 | 
         
            -
             
     | 
| 126 | 
         
             
            # LoRA ๊ฐ์ค์น ๋ก๋
         
     | 
| 127 | 
         
             
            pipe.load_lora_weights(
         
     | 
| 128 | 
         
             
                hf_hub_download(
         
     | 
| 
         @@ -194,66 +167,56 @@ def apply_mask(img: Image.Image, mask_img: Image.Image, defringe: bool = True) - 
     | 
|
| 194 | 
         
             
                return result
         
     | 
| 195 | 
         | 
| 196 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 197 | 
         
             
            def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
         
     | 
| 198 | 
         
             
                """์ ํ๋ ๋น์จ์ ๋ฐ๋ผ ์ด๋ฏธ์ง ํฌ๊ธฐ ๊ณ์ฐ"""
         
     | 
| 199 | 
         
            -
                # FLUX ํ์ดํ๋ผ์ธ์ด ์ง์ํ๋ ์์ ํ ํฌ๊ธฐ ์ฌ์ฉ
         
     | 
| 200 | 
         
             
                if aspect_ratio == "1:1":
         
     | 
| 201 | 
         
            -
                     
     | 
| 202 | 
         
             
                elif aspect_ratio == "16:9":
         
     | 
| 203 | 
         
            -
                     
     | 
| 204 | 
         
             
                elif aspect_ratio == "9:16":
         
     | 
| 205 | 
         
            -
                     
     | 
| 206 | 
         
             
                elif aspect_ratio == "4:3":
         
     | 
| 207 | 
         
            -
                     
     | 
| 208 | 
         
            -
                 
     | 
| 209 | 
         
            -
                    width = height = 512
         
     | 
| 210 | 
         
            -
                
         
     | 
| 211 | 
         
            -
                # 8์ ๋ฐฐ์๋ก ์กฐ์ 
         
     | 
| 212 | 
         
            -
                width = (width // 8) * 8
         
     | 
| 213 | 
         
            -
                height = (height // 8) * 8
         
     | 
| 214 | 
         
            -
                
         
     | 
| 215 | 
         
            -
                return width, height
         
     | 
| 216 | 
         | 
| 
         | 
|
| 217 | 
         
             
            def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
         
     | 
| 218 | 
         
             
                try:
         
     | 
| 219 | 
         
            -
                    # ์์ ํ ํฌ๊ธฐ ๊ณ์ฐ
         
     | 
| 220 | 
         
             
                    width, height = calculate_dimensions(aspect_ratio)
         
     | 
| 
         | 
|
| 221 | 
         | 
| 222 | 
         
            -
                     
     | 
| 223 | 
         
            -
                    
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 224 | 
         
             
                    with timer("Background generation"):
         
     | 
| 225 | 
         
             
                        try:
         
     | 
| 226 | 
         
            -
                            # ๋จผ์  512x512๋ก ์์ฑ
         
     | 
| 227 | 
         
             
                            with torch.inference_mode():
         
     | 
| 228 | 
         
             
                                image = pipe(
         
     | 
| 229 | 
         
             
                                    prompt=prompt,
         
     | 
| 230 | 
         
            -
                                    width= 
     | 
| 231 | 
         
            -
                                    height= 
     | 
| 232 | 
         
             
                                    num_inference_steps=8,
         
     | 
| 233 | 
         
            -
                                    guidance_scale=4.0 
     | 
| 234 | 
         
             
                                ).images[0]
         
     | 
| 235 | 
         
            -
                            
         
     | 
| 236 | 
         
            -
                            # ์ํ๋ ํฌ๊ธฐ๋ก ๋ฆฌ์ฌ์ด์ฆ
         
     | 
| 237 | 
         
            -
                            if width != 512 or height != 512:
         
     | 
| 238 | 
         
            -
                                image = image.resize((width, height), Image.LANCZOS)
         
     | 
| 239 | 
         
            -
                            return image
         
     | 
| 240 | 
         
            -
                            
         
     | 
| 241 | 
         
             
                        except Exception as e:
         
     | 
| 242 | 
         
             
                            print(f"Pipeline error: {str(e)}")
         
     | 
| 243 | 
         
            -
                            # ์๋ฌ ๋ฐ์ ์ ํฐ์ ๋ฐฐ๊ฒฝ ๋ฐํ
         
     | 
| 244 | 
         
             
                            return Image.new('RGB', (width, height), 'white')
         
     | 
| 245 | 
         | 
| 
         | 
|
| 246 | 
         
             
                except Exception as e:
         
     | 
| 247 | 
         
             
                    print(f"Background generation error: {str(e)}")
         
     | 
| 248 | 
         
             
                    return Image.new('RGB', (512, 512), 'white')
         
     | 
| 249 | 
         | 
| 250 | 
         
            -
             
     | 
| 251 | 
         
            -
            def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
         
     | 
| 252 | 
         
            -
                """์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 8์ ๋ฐฐ์๋ก ์กฐ์ """
         
     | 
| 253 | 
         
            -
                new_width = max(8, ((width + 7) // 8) * 8)  # ์ต์ 8ํฝ์
 ๋ณด์ฅ
         
     | 
| 254 | 
         
            -
                new_height = max(8, ((height + 7) // 8) * 8)  # ์ต์ 8ํฝ์
 ๋ณด์ฅ
         
     | 
| 255 | 
         
            -
                return new_width, new_height
         
     | 
| 256 | 
         
            -
             
     | 
| 257 | 
         
             
            def create_position_grid():
         
     | 
| 258 | 
         
             
                return """
         
     | 
| 259 | 
         
             
                <div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
         
     | 
| 
         @@ -310,24 +273,26 @@ def combine_with_background(foreground: Image.Image, background: Image.Image, 
     | 
|
| 310 | 
         
             
                result.paste(scaled_foreground, (x, y), scaled_foreground)
         
     | 
| 311 | 
         
             
                return result
         
     | 
| 312 | 
         | 
| 313 | 
         
            -
            @spaces.GPU(duration= 
     | 
| 314 | 
         
             
            def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
         
     | 
| 
         | 
|
| 315 | 
         
             
                try:
         
     | 
| 316 | 
         
            -
                     
     | 
| 317 | 
         
            -
                         
     | 
| 318 | 
         
            -
             
     | 
| 319 | 
         
            -
             
     | 
| 320 | 
         
            -
             
     | 
| 321 | 
         
            -
             
     | 
| 322 | 
         
            -
                             
     | 
| 323 | 
         
            -
             
     | 
| 324 | 
         
            -
                         
     | 
| 325 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 326 | 
         
             
                except Exception as e:
         
     | 
| 327 | 
         
             
                    print(f"GPU process error: {str(e)}")
         
     | 
| 328 | 
         
             
                    raise
         
     | 
| 329 | 
         
            -
                finally:
         
     | 
| 330 | 
         
            -
                    clear_memory()
         
     | 
| 331 | 
         | 
| 332 | 
         
             
            def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
         
     | 
| 333 | 
         
             
                try:
         
     | 
| 
         @@ -338,12 +303,16 @@ def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str 
     | 
|
| 338 | 
         
             
                        new_size = (int(img.width * ratio), int(img.height * ratio))
         
     | 
| 339 | 
         
             
                        img = img.resize(new_size, Image.LANCZOS)
         
     | 
| 340 | 
         | 
| 341 | 
         
            -
                    # CUDA ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ
         
     | 
| 342 | 
         
            -
                     
     | 
| 343 | 
         
            -
                        torch.cuda. 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 344 | 
         | 
| 345 | 
         
            -
                     
     | 
| 346 | 
         
            -
                    with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
         
     | 
| 347 | 
         
             
                        mask, bbox, time_log = _gpu_process(img, prompt)
         
     | 
| 348 | 
         
             
                        masked_alpha = apply_mask(img, mask, defringe=True)
         
     | 
| 349 | 
         | 
| 
         @@ -376,42 +345,37 @@ def process_prompt(img: Image.Image, prompt: str, bg_prompt: str | None = None, 
     | 
|
| 376 | 
         
             
                              aspect_ratio: str = "1:1", position: str = "bottom-center", 
         
     | 
| 377 | 
         
             
                              scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
         
     | 
| 378 | 
         
             
                try:
         
     | 
| 379 | 
         
            -
                    if img is None or  
     | 
| 380 | 
         
             
                        raise gr.Error("Please provide both image and prompt")
         
     | 
| 381 | 
         | 
| 382 | 
         
            -
                    print(f"Processing with position: {position}, scale: {scale_percent}") 
     | 
| 383 | 
         | 
| 384 | 
         
            -
                     
     | 
| 385 | 
         
            -
             
     | 
| 386 | 
         
            -
             
     | 
| 387 | 
         
            -
             
     | 
| 388 | 
         
            -
             
     | 
| 
         | 
|
| 389 | 
         | 
| 390 | 
         
            -
                     
     | 
| 391 | 
         
            -
                    translated_prompt = translate_to_english(prompt)
         
     | 
| 392 | 
         
            -
                    translated_bg_prompt = translate_to_english(bg_prompt) if bg_prompt else None
         
     | 
| 393 | 
         | 
| 394 | 
         
            -
                     
     | 
| 395 | 
         
            -
             
     | 
| 396 | 
         
            -
             
     | 
| 397 | 
         
            -
             
     | 
| 398 | 
         
            -
             
     | 
| 399 | 
         
            -
             
     | 
| 400 | 
         
            -
                                 
     | 
| 401 | 
         
            -
             
     | 
| 402 | 
         
            -
             
     | 
| 403 | 
         
            -
             
     | 
| 404 | 
         
            -
             
     | 
| 405 | 
         
            -
             
     | 
| 406 | 
         
            -
             
     | 
| 407 | 
         
            -
             
     | 
| 408 | 
         
            -
             
     | 
| 409 | 
         
            -
                                return results[1], results[2]
         
     | 
| 410 | 
         
            -
                        
         
     | 
| 411 | 
         
            -
                        return results[1], results[2]
         
     | 
| 412 | 
         
            -
                        
         
     | 
| 413 | 
         
             
                except Exception as e:
         
     | 
| 414 | 
         
            -
                    print(f" 
     | 
| 415 | 
         
             
                    raise gr.Error(str(e))
         
     | 
| 416 | 
         
             
                finally:
         
     | 
| 417 | 
         
             
                    clear_memory()
         
     | 
| 
         @@ -518,61 +482,9 @@ button.primary:hover { 
     | 
|
| 518 | 
         
             
            }
         
     | 
| 519 | 
         
             
            """
         
     | 
| 520 | 
         | 
| 
         | 
|
| 
         | 
|
| 521 | 
         | 
| 522 | 
         
            -
            def get_image_base64(image_path):
         
     | 
| 523 | 
         
            -
                with open(image_path, "rb") as image_file:
         
     | 
| 524 | 
         
            -
                    return base64.b64encode(image_file.read()).decode()
         
     | 
| 525 | 
         
            -
             
     | 
| 526 | 
         
            -
            # ์ด๋ฏธ์ง๋ฅผ Base64๋ก ๋ณํ
         
     | 
| 527 | 
         
            -
            try:
         
     | 
| 528 | 
         
            -
                example_img1 = get_image_base64("aa1.png")
         
     | 
| 529 | 
         
            -
                example_img2 = get_image_base64("aa2.png")
         
     | 
| 530 | 
         
            -
                example_img3 = get_image_base64("aa3.png")
         
     | 
| 531 | 
         
            -
            except Exception as e:
         
     | 
| 532 | 
         
            -
                print(f"Error loading example images: {e}")
         
     | 
| 533 | 
         
            -
                example_img1 = example_img2 = example_img3 = ""
         
     | 
| 534 | 
         
            -
             
     | 
| 535 | 
         
            -
            # HTML ํ
ํ๋ฆฟ ์์ 
         
     | 
| 536 | 
         
            -
            example_html = f"""
         
     | 
| 537 | 
         
            -
            <div style="margin-top: 50px; padding: 20px; background-color: #f8f9fa; border-radius: 10px;">
         
     | 
| 538 | 
         
            -
                <h2 style="text-align: center; color: #2196F3; margin-bottom: 30px;">How It Works: Step by Step Guide</h2>
         
     | 
| 539 | 
         
            -
                
         
     | 
| 540 | 
         
            -
                <div style="display: flex; justify-content: space-around; align-items: center; flex-wrap: wrap; gap: 20px;">
         
     | 
| 541 | 
         
            -
                    <div style="text-align: center; flex: 1; min-width: 250px; max-width: 300px;">
         
     | 
| 542 | 
         
            -
                        <img src="data:image/png;base64,{example_img1}" 
         
     | 
| 543 | 
         
            -
                             style="width: 100%; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
         
     | 
| 544 | 
         
            -
                        <h3 style="color: #333; margin: 15px 0;">Step 1: Original Image</h3>
         
     | 
| 545 | 
         
            -
                        <p style="color: #666;">Upload your original image containing the object you want to extract.</p>
         
     | 
| 546 | 
         
            -
                    </div>
         
     | 
| 547 | 
         
            -
                    
         
     | 
| 548 | 
         
            -
                    <div style="text-align: center; flex: 1; min-width: 250px; max-width: 300px;">
         
     | 
| 549 | 
         
            -
                        <img src="data:image/png;base64,{example_img2}" 
         
     | 
| 550 | 
         
            -
                             style="width: 100%; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
         
     | 
| 551 | 
         
            -
                        <h3 style="color: #333; margin: 15px 0;">Step 2: Object Extraction</h3>
         
     | 
| 552 | 
         
            -
                        <p style="color: #666;">AI automatically detects and extracts the specified object.</p>
         
     | 
| 553 | 
         
            -
                    </div>
         
     | 
| 554 | 
         
            -
                    
         
     | 
| 555 | 
         
            -
                    <div style="text-align: center; flex: 1; min-width: 250px; max-width: 300px;">
         
     | 
| 556 | 
         
            -
                        <img src="data:image/png;base64,{example_img3}" 
         
     | 
| 557 | 
         
            -
                             style="width: 100%; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
         
     | 
| 558 | 
         
            -
                        <h3 style="color: #333; margin: 15px 0;">Step 3: Final Result</h3>
         
     | 
| 559 | 
         
            -
                        <p style="color: #666;">The extracted object is placed on an AI-generated background.</p>
         
     | 
| 560 | 
         
            -
                    </div>
         
     | 
| 561 | 
         
            -
                </div>
         
     | 
| 562 | 
         
            -
             
     | 
| 563 | 
         
            -
                <div style="margin-top: 30px; text-align: center; padding: 20px; background-color: #e3f2fd; border-radius: 8px;">
         
     | 
| 564 | 
         
            -
                    <h4 style="color: #1976D2; margin-bottom: 10px;">Key Features:</h4>
         
     | 
| 565 | 
         
            -
                    <ul style="list-style: none; padding: 0;">
         
     | 
| 566 | 
         
            -
                        <li style="margin: 5px 0;">โจ Advanced AI-powered object detection and extraction</li>
         
     | 
| 567 | 
         
            -
                        <li style="margin: 5px 0;">๐จ Custom background generation with text prompts</li>
         
     | 
| 568 | 
         
            -
                        <li style="margin: 5px 0;">๐ Flexible object positioning and sizing options</li>
         
     | 
| 569 | 
         
            -
                        <li style="margin: 5px 0;">๐ Multiple aspect ratio support for various use cases</li>
         
     | 
| 570 | 
         
            -
                    </ul>
         
     | 
| 571 | 
         
            -
                </div>
         
     | 
| 572 | 
         
            -
            </div>
         
     | 
| 573 | 
         
            -
            """
         
     | 
| 574 | 
         
            -
             
     | 
| 575 | 
         
            -
                
         
     | 
| 576 | 
         
             
            # UI ๊ตฌ์ฑ
         
     | 
| 577 | 
         
             
            with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         
     | 
| 578 | 
         
             
                gr.HTML("""
         
     | 
| 
         @@ -582,9 +494,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: 
     | 
|
| 582 | 
         
             
                    </div>
         
     | 
| 583 | 
         
             
                """)
         
     | 
| 584 | 
         | 
| 585 | 
         
            -
                # ์์  ์น์
 ์ถ๊ฐ
         
     | 
| 586 | 
         
            -
                gr.HTML(example_html)
         
     | 
| 587 | 
         
            -
                
         
     | 
| 588 | 
         
             
                with gr.Row():
         
     | 
| 589 | 
         
             
                    with gr.Column(scale=1):
         
     | 
| 590 | 
         
             
                        input_image = gr.Image(
         
     | 
| 
         @@ -632,7 +541,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: 
     | 
|
| 632 | 
         
             
                                scale_slider = gr.Slider(
         
     | 
| 633 | 
         
             
                                    minimum=10,
         
     | 
| 634 | 
         
             
                                    maximum=200,
         
     | 
| 635 | 
         
            -
                                    value= 
     | 
| 636 | 
         
             
                                    step=5,
         
     | 
| 637 | 
         
             
                                    label="Object Size (%)"
         
     | 
| 638 | 
         
             
                                )
         
     | 
| 
         @@ -689,13 +598,12 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: 
     | 
|
| 689 | 
         
             
                )
         
     | 
| 690 | 
         | 
| 691 | 
         
             
                def update_controls(bg_prompt):
         
     | 
| 
         | 
|
| 692 | 
         
             
                    is_visible = bool(bg_prompt)
         
     | 
| 693 | 
         
             
                    return [
         
     | 
| 694 | 
         
            -
                        gr.update(visible=is_visible 
     | 
| 695 | 
         
             
                        gr.update(visible=is_visible),  # object_controls
         
     | 
| 696 | 
         
             
                    ]
         
     | 
| 697 | 
         
            -
                
         
     | 
| 698 | 
         
            -
             
     | 
| 699 | 
         | 
| 700 | 
         
             
                bg_prompt.change(
         
     | 
| 701 | 
         
             
                    fn=update_controls,
         
     | 
| 
         @@ -724,4 +632,5 @@ demo.launch( 
     | 
|
| 724 | 
         
             
                server_name="0.0.0.0",
         
     | 
| 725 | 
         
             
                server_port=7860,
         
     | 
| 726 | 
         
             
                share=False,
         
     | 
| 727 | 
         
            -
                max_threads=2 
     | 
| 
         | 
| 
         | 
|
| 20 | 
         
             
            from diffusers import FluxPipeline
         
     | 
| 21 | 
         
             
            from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
         
     | 
| 22 | 
         
             
            import gc
         
     | 
| 
         | 
|
| 23 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 24 | 
         
             
            def clear_memory():
         
     | 
| 25 | 
         
             
                """๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ํจ์"""
         
     | 
| 26 | 
         
             
                gc.collect()
         
     | 
| 27 | 
         
            +
                try:
         
     | 
| 28 | 
         
            +
                    if torch.cuda.is_available():
         
     | 
| 29 | 
         
            +
                        with torch.cuda.device(0):  # ๋ช
์์ ์ผ๋ก device 0 ์ฌ์ฉ
         
     | 
| 30 | 
         
             
                            torch.cuda.empty_cache()
         
     | 
| 31 | 
         
            +
                except:
         
     | 
| 32 | 
         
            +
                    pass
         
     | 
| 
         | 
|
| 33 | 
         | 
| 34 | 
         
            +
            # GPU ์ค์ 
         
     | 
| 35 | 
         
            +
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # ๋ช
์์ ์ผ๋ก cuda:0 ์ง์ 
         
     | 
| 36 | 
         | 
| 37 | 
         
             
            # GPU ์ค์ ์ try-except๋ก ๊ฐ์ธ๊ธฐ
         
     | 
| 38 | 
         
             
            if torch.cuda.is_available():
         
     | 
| 
         | 
|
| 88 | 
         
             
            gd_model = gd_model.to(device=device)
         
     | 
| 89 | 
         
             
            assert isinstance(gd_model, GroundingDinoForObjectDetection)
         
     | 
| 90 | 
         | 
| 91 | 
         
            +
            # FLUX ํ์ดํ๋ผ์ธ ์ด๊ธฐํ
         
     | 
| 92 | 
         
             
            pipe = FluxPipeline.from_pretrained(
         
     | 
| 93 | 
         
             
                "black-forest-labs/FLUX.1-dev",
         
     | 
| 94 | 
         
             
                torch_dtype=torch.float16,
         
     | 
| 95 | 
         
             
                use_auth_token=HF_TOKEN
         
     | 
| 96 | 
         
             
            )
         
     | 
| 
         | 
|
| 
         | 
|
| 97 | 
         
             
            pipe.enable_attention_slicing(slice_size="auto")
         
     | 
| 98 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 99 | 
         
             
            # LoRA ๊ฐ์ค์น ๋ก๋
         
     | 
| 100 | 
         
             
            pipe.load_lora_weights(
         
     | 
| 101 | 
         
             
                hf_hub_download(
         
     | 
| 
         | 
|
| 167 | 
         
             
                return result
         
     | 
| 168 | 
         | 
| 169 | 
         | 
| 170 | 
         
            +
            def adjust_size_to_multiple_of_8(width: int, height: int) -> tuple[int, int]:
         
     | 
| 171 | 
         
            +
                """์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ 8์ ๋ฐฐ์๋ก ์กฐ์ ํ๋ ํจ์"""
         
     | 
| 172 | 
         
            +
                new_width = ((width + 7) // 8) * 8
         
     | 
| 173 | 
         
            +
                new_height = ((height + 7) // 8) * 8
         
     | 
| 174 | 
         
            +
                return new_width, new_height
         
     | 
| 175 | 
         
            +
             
     | 
| 176 | 
         
             
            def calculate_dimensions(aspect_ratio: str, base_size: int = 512) -> tuple[int, int]:
                """Compute ``(width, height)`` for an aspect ratio written as ``"W:H"``.

                The shorter side is always ``base_size``; the longer side is scaled
                from it with integer (floor) division. The results for "1:1", "16:9",
                "9:16" and "4:3" are the same as the previously hard-coded branches.

                Args:
                    aspect_ratio: Ratio string such as "16:9" or "3:4".
                    base_size: Length in pixels of the shorter side.

                Returns:
                    ``(width, height)``. Anything that is not two positive integers
                    separated by a colon falls back to a ``base_size`` square.
                """
                try:
                    ratio_w, ratio_h = (int(part) for part in aspect_ratio.split(":"))
                except (AttributeError, ValueError):
                    # Not a string, wrong number of parts, or non-numeric parts.
                    return base_size, base_size

                if ratio_w <= 0 or ratio_h <= 0:
                    return base_size, base_size

                if ratio_w >= ratio_h:
                    # Landscape (or square): height is the short side.
                    return base_size * ratio_w // ratio_h, base_size
                # Portrait: width is the short side.
                return base_size, base_size * ratio_h // ratio_w
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 187 | 
         | 
| 188 | 
         
            +
            @spaces.GPU(duration=20)  # reduced from 40 s to 20 s
            def generate_background(prompt: str, aspect_ratio: str) -> Image.Image:
                """Generate a background image for ``prompt`` at the requested aspect ratio.

                The target size comes from ``calculate_dimensions``, is snapped to a
                multiple of 8, and is scaled down so that neither side exceeds 768 px.
                The pipeline runs for 8 inference steps with a guidance scale of 4.0.

                Returns:
                    The generated image. If the pipeline call fails, a white image of
                    the computed size is returned; if anything else fails, a white
                    512x512 image is returned.
                """
                try:
                    width, height = adjust_size_to_multiple_of_8(
                        *calculate_dimensions(aspect_ratio)
                    )

                    # Cap the longer side, then re-snap to a multiple of 8.
                    max_size = 768
                    longest_side = max(width, height)
                    if longest_side > max_size:
                        ratio = max_size / longest_side
                        width, height = adjust_size_to_multiple_of_8(
                            int(width * ratio), int(height * ratio)
                        )

                    with timer("Background generation"):
                        try:
                            with torch.inference_mode():
                                output = pipe(
                                    prompt=prompt,
                                    width=width,
                                    height=height,
                                    num_inference_steps=8,
                                    guidance_scale=4.0,
                                )
                                image = output.images[0]
                        except Exception as pipeline_err:
                            print(f"Pipeline error: {str(pipeline_err)}")
                            return Image.new('RGB', (width, height), 'white')

                    return image
                except Exception as err:
                    print(f"Background generation error: {str(err)}")
                    return Image.new('RGB', (512, 512), 'white')
         
     | 
| 219 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 220 | 
         
             
            def create_position_grid():
         
     | 
| 221 | 
         
             
                return """
         
     | 
| 222 | 
         
             
                <div class="position-grid" style="display: grid; grid-template-columns: repeat(3, 1fr); gap: 10px; width: 150px; margin: auto;">
         
     | 
| 
         | 
|
| 273 | 
         
             
                result.paste(scaled_foreground, (x, y), scaled_foreground)
         
     | 
| 274 | 
         
             
                return result
         
     | 
| 275 | 
         | 
| 276 | 
         
            +
            @spaces.GPU(duration=30)  # reduced from 120 s to 30 s
            def _gpu_process(img: Image.Image, prompt: str | BoundingBox | None) -> tuple[Image.Image, BoundingBox | None, list[str]]:
                """Run detection (when needed) and segmentation for ``img``.

                Args:
                    img: Input image.
                    prompt: A text prompt, which is passed to ``gd_detect`` to obtain a
                        bounding box, or any non-string value, which is used as the
                        bounding box directly.

                Returns:
                    ``(mask, bbox, time_log)`` where ``mask`` is the result of
                    ``segmenter(img, bbox)`` and ``time_log`` holds one timing string
                    per stage that ran.

                Raises:
                    gr.Error: If a text prompt was given and detection returned nothing.
                    Exception: Any other failure is logged and re-raised unchanged.
                """
                time_log: list[str] = []
                try:
                    if not isinstance(prompt, str):
                        # Caller supplied the box (or None); skip detection.
                        bbox = prompt
                    else:
                        started = time.time()
                        bbox = gd_detect(img, prompt)
                        time_log.append(f"detect: {time.time() - started}")
                        if not bbox:
                            print(time_log[0])
                            raise gr.Error("No object detected")

                    started = time.time()
                    mask = segmenter(img, bbox)
                    time_log.append(f"segment: {time.time() - started}")
                    return mask, bbox, time_log
                except Exception as err:
                    print(f"GPU process error: {str(err)}")
                    raise
         
     | 
| 
         | 
|
| 
         | 
|
| 296 | 
         | 
| 297 | 
         
             
            def _process(img: Image.Image, prompt: str | BoundingBox | None, bg_prompt: str | None = None, aspect_ratio: str = "1:1") -> tuple[tuple[Image.Image, Image.Image, Image.Image], gr.DownloadButton]:
         
     | 
| 298 | 
         
             
                try:
         
     | 
| 
         | 
|
| 303 | 
         
             
                        new_size = (int(img.width * ratio), int(img.height * ratio))
         
     | 
| 304 | 
         
             
                        img = img.resize(new_size, Image.LANCZOS)
         
     | 
| 305 | 
         | 
| 306 | 
         
            +
                    # CUDA ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ์์ 
         
     | 
| 307 | 
         
            +
                    try:
         
     | 
| 308 | 
         
            +
                        if torch.cuda.is_available():
         
     | 
| 309 | 
         
            +
                            current_device = torch.cuda.current_device()
         
     | 
| 310 | 
         
            +
                            with torch.cuda.device(current_device):
         
     | 
| 311 | 
         
            +
                                torch.cuda.empty_cache()
         
     | 
| 312 | 
         
            +
                    except Exception as e:
         
     | 
| 313 | 
         
            +
                        print(f"CUDA memory management failed: {e}")
         
     | 
| 314 | 
         | 
| 315 | 
         
            +
                    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
         
     | 
| 
         | 
|
| 316 | 
         
             
                        mask, bbox, time_log = _gpu_process(img, prompt)
         
     | 
| 317 | 
         
             
                        masked_alpha = apply_mask(img, mask, defringe=True)
         
     | 
| 318 | 
         | 
| 
         | 
|
| 345 | 
         
             
                              aspect_ratio: str = "1:1", position: str = "bottom-center", 
         
     | 
| 346 | 
         
             
                              scale_percent: float = 100) -> tuple[Image.Image, Image.Image]:
         
     | 
| 347 | 
         
             
                try:
         
     | 
| 348 | 
         
            +
                    if img is None or prompt.strip() == "":
         
     | 
| 349 | 
         
             
                        raise gr.Error("Please provide both image and prompt")
         
     | 
| 350 | 
         | 
| 351 | 
         
            +
                    print(f"Processing with position: {position}, scale: {scale_percent}")
         
     | 
| 352 | 
         | 
| 353 | 
         
            +
                    try:
         
     | 
| 354 | 
         
            +
                        prompt = translate_to_english(prompt)
         
     | 
| 355 | 
         
            +
                        if bg_prompt:
         
     | 
| 356 | 
         
            +
                            bg_prompt = translate_to_english(bg_prompt)
         
     | 
| 357 | 
         
            +
                    except Exception as e:
         
     | 
| 358 | 
         
            +
                        print(f"Translation error (continuing with original text): {str(e)}")
         
     | 
| 359 | 
         | 
| 360 | 
         
            +
                    results, _ = _process(img, prompt, bg_prompt, aspect_ratio)
         
     | 
| 
         | 
|
| 
         | 
|
| 361 | 
         | 
| 362 | 
         
            +
                    if bg_prompt:
         
     | 
| 363 | 
         
            +
                        try:
         
     | 
| 364 | 
         
            +
                            combined = combine_with_background(
         
     | 
| 365 | 
         
            +
                                foreground=results[2],
         
     | 
| 366 | 
         
            +
                                background=results[1],
         
     | 
| 367 | 
         
            +
                                position=position,
         
     | 
| 368 | 
         
            +
                                scale_percent=scale_percent
         
     | 
| 369 | 
         
            +
                            )
         
     | 
| 370 | 
         
            +
                            print(f"Combined image created with position: {position}")
         
     | 
| 371 | 
         
            +
                            return combined, results[2]
         
     | 
| 372 | 
         
            +
                        except Exception as e:
         
     | 
| 373 | 
         
            +
                            print(f"Combination error: {str(e)}")
         
     | 
| 374 | 
         
            +
                            return results[1], results[2]
         
     | 
| 375 | 
         
            +
                    
         
     | 
| 376 | 
         
            +
                    return results[1], results[2]
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 377 | 
         
             
                except Exception as e:
         
     | 
| 378 | 
         
            +
                    print(f"Error in process_prompt: {str(e)}")
         
     | 
| 379 | 
         
             
                    raise gr.Error(str(e))
         
     | 
| 380 | 
         
             
                finally:
         
     | 
| 381 | 
         
             
                    clear_memory()
         
     | 
| 
         | 
|
| 482 | 
         
             
            }
         
     | 
| 483 | 
         
             
            """
         
     | 
| 484 | 
         | 
| 485 | 
         
            +
            # UI layout
         
     | 
| 486 | 
         
            +
            # In the UI layout section, process_btn was moved up and the position_grid.click handler was removed
         
     | 
| 487 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 488 | 
         
             
            # UI layout
         
     | 
| 489 | 
         
             
            with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
         
     | 
| 490 | 
         
             
                gr.HTML("""
         
     | 
| 
         | 
|
| 494 | 
         
             
                    </div>
         
     | 
| 495 | 
         
             
                """)
         
     | 
| 496 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 497 | 
         
             
                with gr.Row():
         
     | 
| 498 | 
         
             
                    with gr.Column(scale=1):
         
     | 
| 499 | 
         
             
                        input_image = gr.Image(
         
     | 
| 
         | 
|
| 541 | 
         
             
                                scale_slider = gr.Slider(
         
     | 
| 542 | 
         
             
                                    minimum=10,
         
     | 
| 543 | 
         
             
                                    maximum=200,
         
     | 
| 544 | 
         
            +
                                    value=50,
         
     | 
| 545 | 
         
             
                                    step=5,
         
     | 
| 546 | 
         
             
                                    label="Object Size (%)"
         
     | 
| 547 | 
         
             
                                )
         
     | 
| 
         | 
|
| 598 | 
         
             
                )
         
     | 
| 599 | 
         | 
| 600 | 
         
             
                def update_controls(bg_prompt):
                    """Show or hide controls depending on whether a background prompt was entered.

                    Returns two ``gr.update`` objects with the same ``visible`` flag:
                    the first for the aspect-ratio control, the second for the object
                    controls.
                """
                    show = True if bg_prompt else False
                    aspect_ratio_update = gr.update(visible=show)  # aspect_ratio
                    object_controls_update = gr.update(visible=show)  # object_controls
                    return [aspect_ratio_update, object_controls_update]
         
     | 
| 
         | 
|
| 
         | 
|
| 607 | 
         | 
| 608 | 
         
             
                bg_prompt.change(
         
     | 
| 609 | 
         
             
                    fn=update_controls,
         
     | 
| 
         | 
|
| 632 | 
         
             
                server_name="0.0.0.0",
         
     | 
| 633 | 
         
             
                server_port=7860,
         
     | 
| 634 | 
         
             
                share=False,
         
     | 
| 635 | 
         
            +
                max_threads=2  # ์ค๋ ๋ ์ ์ ํ
         
     | 
| 636 | 
         
            +
            )
         
     |