Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -19,11 +19,50 @@ from trellis.representations import Gaussian, MeshExtractResult | |
| 19 | 
             
            from trellis.utils import render_utils, postprocessing_utils
         | 
| 20 | 
             
            from diffusers import FluxPipeline
         | 
| 21 | 
             
            from typing import Tuple, Dict, Any  # Tuple import ์ถ๊ฐ
         | 
| 22 | 
            -
             | 
| 23 | 
            -
            # ํ์ผ ์๋จ์ import ๋ฌธ ์์ 
         | 
| 24 | 
             
            import transformers
         | 
| 25 | 
             
            from transformers import pipeline as transformers_pipeline
         | 
| 26 | 
            -
            from transformers import Pipeline | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 27 | 
             
            # CUDA ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ์ค์ 
         | 
| 28 | 
             
            torch.cuda.empty_cache()
         | 
| 29 | 
             
            torch.backends.cuda.matmul.allow_tf32 = True
         | 
| @@ -71,7 +110,7 @@ class timer: | |
| 71 |  | 
| 72 | 
             
            def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
         | 
| 73 | 
             
                trial_id = str(uuid.uuid4())
         | 
| 74 | 
            -
                processed_image =  | 
| 75 | 
             
                processed_image.save(f"{TMP_DIR}/{trial_id}.png")
         | 
| 76 | 
             
                return trial_id, processed_image
         | 
| 77 |  | 
| @@ -169,7 +208,7 @@ def text_to_image(prompt: str, height: int, width: int, steps: int, scales: floa | |
| 169 |  | 
| 170 | 
             
                # ํ๋กฌํํธ ์ ์ฒ๋ฆฌ
         | 
| 171 | 
             
                if contains_korean(prompt):
         | 
| 172 | 
            -
                    translated = translator(prompt)[0]['translation_text']
         | 
| 173 | 
             
                    prompt = translated
         | 
| 174 |  | 
| 175 | 
             
                # ํ๋กฌํํธ ํ์ ๊ฐ์ 
         | 
| @@ -177,7 +216,7 @@ def text_to_image(prompt: str, height: int, width: int, steps: int, scales: floa | |
| 177 |  | 
| 178 | 
             
                with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
         | 
| 179 | 
             
                    try:
         | 
| 180 | 
            -
                        generated_image = flux_pipe( | 
| 181 | 
             
                            prompt=[formatted_prompt],
         | 
| 182 | 
             
                            generator=torch.Generator().manual_seed(int(seed)),
         | 
| 183 | 
             
                            num_inference_steps=int(steps),
         | 
| @@ -330,35 +369,8 @@ if __name__ == "__main__": | |
| 330 | 
             
                print(f"Using device: {device}")
         | 
| 331 |  | 
| 332 | 
             
                try:
         | 
| 333 | 
            -
                    #  | 
| 334 | 
            -
                     | 
| 335 | 
            -
                        "JeffreyXiang/TRELLIS-image-large"
         | 
| 336 | 
            -
                    )
         | 
| 337 | 
            -
                    trellis_pipeline.to(device)
         | 
| 338 | 
            -
                    
         | 
| 339 | 
            -
                    # ์ด๋ฏธ์ง ์์ฑ ํ์ดํ๋ผ์ธ
         | 
| 340 | 
            -
                    flux_pipe = FluxPipeline.from_pretrained(
         | 
| 341 | 
            -
                        "black-forest-labs/FLUX.1-dev",
         | 
| 342 | 
            -
                        torch_dtype=torch.bfloat16,
         | 
| 343 | 
            -
                        device_map="balanced"
         | 
| 344 | 
            -
                    )
         | 
| 345 | 
            -
                    
         | 
| 346 | 
            -
                    # Hyper-SD LoRA ๋ก๋
         | 
| 347 | 
            -
                    lora_path = hf_hub_download(
         | 
| 348 | 
            -
                        "ByteDance/Hyper-SD",
         | 
| 349 | 
            -
                        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
         | 
| 350 | 
            -
                        use_auth_token=HF_TOKEN
         | 
| 351 | 
            -
                    )
         | 
| 352 | 
            -
                    flux_pipe.load_lora_weights(lora_path)
         | 
| 353 | 
            -
                    flux_pipe.fuse_lora(lora_scale=0.125)
         | 
| 354 | 
            -
                    
         | 
| 355 | 
            -
                    # ๋ฒ์ญ๊ธฐ ์ด๊ธฐํ
         | 
| 356 | 
            -
                    global translator
         | 
| 357 | 
            -
                    translator = transformers_pipeline(
         | 
| 358 | 
            -
                        "translation", 
         | 
| 359 | 
            -
                        model="Helsinki-NLP/opus-mt-ko-en",
         | 
| 360 | 
            -
                        device=device
         | 
| 361 | 
            -
                    )
         | 
| 362 |  | 
| 363 | 
             
                    # CUDA ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ
         | 
| 364 | 
             
                    if torch.cuda.is_available():
         | 
| @@ -367,7 +379,7 @@ if __name__ == "__main__": | |
| 367 | 
             
                    # ์ด๊ธฐ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํ
์คํธ
         | 
| 368 | 
             
                    try:
         | 
| 369 | 
             
                        test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
         | 
| 370 | 
            -
                        trellis_pipeline.preprocess_image(test_image)
         | 
| 371 | 
             
                    except Exception as e:
         | 
| 372 | 
             
                        print(f"Warning: Initial preprocessing test failed: {e}")
         | 
| 373 |  | 
|  | |
| 19 | 
             
            from trellis.utils import render_utils, postprocessing_utils
         | 
| 20 | 
             
            from diffusers import FluxPipeline
         | 
| 21 | 
             
            from typing import Tuple, Dict, Any  # Tuple import ์ถ๊ฐ
         | 
| 22 | 
            +
            # ํ์ผ ์๋จ์ import ๋ฌธ
         | 
|  | |
| 23 | 
             
            import transformers
         | 
| 24 | 
             
            from transformers import pipeline as transformers_pipeline
         | 
| 25 | 
            +
            from transformers import Pipeline
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # ์ ์ญ ๋ณ์ ์ด๊ธฐํ
         | 
| 28 | 
            +
            class GlobalVars:
         | 
| 29 | 
            +
                def __init__(self):
         | 
| 30 | 
            +
                    self.translator = None
         | 
| 31 | 
            +
                    self.trellis_pipeline = None
         | 
| 32 | 
            +
                    self.flux_pipe = None
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            g = GlobalVars()
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            def initialize_models(device):
         | 
| 37 | 
            +
                # 3D ์์ฑ ํ์ดํ๋ผ์ธ
         | 
| 38 | 
            +
                g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
         | 
| 39 | 
            +
                    "JeffreyXiang/TRELLIS-image-large"
         | 
| 40 | 
            +
                )
         | 
| 41 | 
            +
                g.trellis_pipeline.to(device)
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                # ์ด๋ฏธ์ง ์์ฑ ํ์ดํ๋ผ์ธ
         | 
| 44 | 
            +
                g.flux_pipe = FluxPipeline.from_pretrained(
         | 
| 45 | 
            +
                    "black-forest-labs/FLUX.1-dev",
         | 
| 46 | 
            +
                    torch_dtype=torch.bfloat16,
         | 
| 47 | 
            +
                    device_map="balanced"
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                # Hyper-SD LoRA ๋ก๋
         | 
| 51 | 
            +
                lora_path = hf_hub_download(
         | 
| 52 | 
            +
                    "ByteDance/Hyper-SD",
         | 
| 53 | 
            +
                    "Hyper-FLUX.1-dev-8steps-lora.safetensors",
         | 
| 54 | 
            +
                    use_auth_token=HF_TOKEN
         | 
| 55 | 
            +
                )
         | 
| 56 | 
            +
                g.flux_pipe.load_lora_weights(lora_path)
         | 
| 57 | 
            +
                g.flux_pipe.fuse_lora(lora_scale=0.125)
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                # ๋ฒ์ญ๊ธฐ ์ด๊ธฐํ
         | 
| 60 | 
            +
                g.translator = transformers_pipeline(
         | 
| 61 | 
            +
                    "translation", 
         | 
| 62 | 
            +
                    model="Helsinki-NLP/opus-mt-ko-en",
         | 
| 63 | 
            +
                    device=device
         | 
| 64 | 
            +
                )
         | 
| 65 | 
            +
                
         | 
| 66 | 
             
            # CUDA ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ์ค์ 
         | 
| 67 | 
             
            torch.cuda.empty_cache()
         | 
| 68 | 
             
            torch.backends.cuda.matmul.allow_tf32 = True
         | 
|  | |
| 110 |  | 
| 111 | 
             
            def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
         | 
| 112 | 
             
                trial_id = str(uuid.uuid4())
         | 
| 113 | 
            +
                processed_image = g.trellis_pipeline.preprocess_image(image)
         | 
| 114 | 
             
                processed_image.save(f"{TMP_DIR}/{trial_id}.png")
         | 
| 115 | 
             
                return trial_id, processed_image
         | 
| 116 |  | 
|  | |
| 208 |  | 
| 209 | 
             
                # ํ๋กฌํํธ ์ ์ฒ๋ฆฌ
         | 
| 210 | 
             
                if contains_korean(prompt):
         | 
| 211 | 
            +
                    translated = g.translator(prompt)[0]['translation_text']
         | 
| 212 | 
             
                    prompt = translated
         | 
| 213 |  | 
| 214 | 
             
                # ํ๋กฌํํธ ํ์ ๊ฐ์ 
         | 
|  | |
| 216 |  | 
| 217 | 
             
                with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
         | 
| 218 | 
             
                    try:
         | 
| 219 | 
            +
                        generated_image = g.flux_pipe(
         | 
| 220 | 
             
                            prompt=[formatted_prompt],
         | 
| 221 | 
             
                            generator=torch.Generator().manual_seed(int(seed)),
         | 
| 222 | 
             
                            num_inference_steps=int(steps),
         | 
|  | |
| 369 | 
             
                print(f"Using device: {device}")
         | 
| 370 |  | 
| 371 | 
             
                try:
         | 
| 372 | 
            +
                    # ๋ชจ๋ธ ์ด๊ธฐํ
         | 
| 373 | 
            +
                    initialize_models(device)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 374 |  | 
| 375 | 
             
                    # CUDA ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ
         | 
| 376 | 
             
                    if torch.cuda.is_available():
         | 
|  | |
| 379 | 
             
                    # ์ด๊ธฐ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ํ
์คํธ
         | 
| 380 | 
             
                    try:
         | 
| 381 | 
             
                        test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
         | 
| 382 | 
            +
                        g.trellis_pipeline.preprocess_image(test_image)
         | 
| 383 | 
             
                    except Exception as e:
         | 
| 384 | 
             
                        print(f"Warning: Initial preprocessing test failed: {e}")
         | 
| 385 |  | 
 
			
